Commit d34de4c4 authored by mashun1

add_loop

parent 4352b6e6
A new one-line text file adds the prompt "a beach with waves and clouds at sunset" (no newline at end of file).
The first change is to the image-to-video wrapper used for single-image animation. Relative to the previous version, the commit removes the unused `**kwargs`/`self.kwargs` plumbing from `Image2Video.__init__` and moves the `torch.no_grad()`/`torch.cuda.amp.autocast()` context up so that it covers conditioning as well as DDIM sampling. The file after the change:

import os
import time
from omegaconf import OmegaConf
import torch
from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
from utils.utils import instantiate_from_config
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything


class Image2Video():
    def __init__(self, result_dir='./tmp/', gpu_num=1, resolution='256_256') -> None:
        self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1]))  # (h, w)
        self.download_model()
        self.result_dir = result_dir
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)
        ckpt_path = 'checkpoints/dynamicrafter_' + resolution.split('_')[1] + '_v1/model.ckpt'
        config_file = 'configs/inference_' + resolution.split('_')[1] + '_v1.0.yaml'
        config = OmegaConf.load(config_file)
        model_config = config.pop("model", OmegaConf.create())
        model_config['params']['unet_config']['params']['use_checkpoint'] = False
        model_list = []
        for gpu_id in range(gpu_num):
            model = instantiate_from_config(model_config)
            # model = model.cuda(gpu_id)
            assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
            model = load_model_checkpoint(model, ckpt_path)
            model.eval()
            model_list.append(model)
        self.model_list = model_list
        self.save_fps = 8

    def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
        seed_everything(seed)
        transform = transforms.Compose([
            transforms.Resize(min(self.resolution)),
            transforms.CenterCrop(self.resolution),
        ])
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()
        gpu_id = 0
        if steps > 60:
            steps = 60
        model = self.model_list[gpu_id]
        model = model.cuda()
        batch_size = 1
        channels = model.model.diffusion_model.out_channels
        frames = model.temporal_length
        h, w = self.resolution[0] // 8, self.resolution[1] // 8
        noise_shape = [batch_size, channels, frames, h, w]

        # text cond
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_emb = model.get_learned_conditioning([prompt])

            # img cond: `image` is an RGB array of shape (H, W, 3) in [0, 255]
            img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
            img_tensor = (img_tensor / 255. - 0.5) * 2

            image_tensor_resized = transform(img_tensor)  # 3,h,w
            videos = image_tensor_resized.unsqueeze(0)  # b,c,h,w

            z = get_latent_z(model, videos.unsqueeze(2))  # b,c,1,h,w

            img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)

            cond_images = model.embedder(img_tensor.unsqueeze(0))  # b,l,c
            img_emb = model.image_proj_model(cond_images)

            imtext_cond = torch.cat([text_emb, img_emb], dim=1)

            fs = torch.tensor([fs], dtype=torch.long, device=model.device)
            cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

            ## inference
            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
        ## b,samples,c,t,h,w
        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
        prompt_str = prompt_str[:40]
        if len(prompt_str) == 0:
            prompt_str = 'empty_prompt'

        save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
        model = model.cpu()
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    def download_model(self):
        REPO_ID = 'Doubiiu/DynamiCrafter_' + str(self.resolution[1]) if self.resolution[1] != 256 else 'Doubiiu/DynamiCrafter'
        filename_list = ['model.ckpt']
        if not os.path.exists('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_v1/'):
            os.makedirs('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_v1/')
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_v1/', filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename,
                                local_dir='./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_v1/',
                                local_dir_use_symlinks=False)


if __name__ == '__main__':
    i2v = Image2Video()
    # NOTE: get_image expects a decoded image array, not a path; see the sketch below.
    video_path = i2v.get_image('prompts/art.png', 'man fishing in a boat at sunset')
    print('done', video_path)
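Note that `get_image` consumes a decoded image array, while the `__main__` stub above passes a file path. A minimal sketch of the intended call, assuming PIL/numpy for loading (the loading code is an addition for illustration, not part of the commit):

import numpy as np
from PIL import Image

# Decode to an (H, W, 3) uint8 RGB array, as expected by get_image.
image = np.array(Image.open('prompts/art.png').convert('RGB'))
i2v = Image2Video()
print('done', i2v.get_image(image, 'man fishing in a boat at sunset'))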
The main addition is a second wrapper that supports both frame interpolation and looping, built on the interpolation checkpoint (`dynamicrafter_*_interp_v1`). `get_image` gains an optional `image2` argument: when supplied, the clip interpolates from `image` to `image2`; when omitted, the first and last latent frames are both conditioned on `image`, producing a loopable clip. A usage sketch follows the listing.

import os
import time
from omegaconf import OmegaConf
import torch
from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
from utils.utils import instantiate_from_config
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything


class Image2Video():
    def __init__(self, result_dir='./tmp/', gpu_num=1, resolution='256_256') -> None:
        self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1]))  # (h, w)
        self.download_model()
        self.result_dir = result_dir
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)
        ckpt_path = 'checkpoints/dynamicrafter_' + resolution.split('_')[1] + '_interp_v1/model.ckpt'
        config_file = 'configs/inference_' + resolution.split('_')[1] + '_v1.0.yaml'
        config = OmegaConf.load(config_file)
        model_config = config.pop("model", OmegaConf.create())
        model_config['params']['unet_config']['params']['use_checkpoint'] = False
        model_list = []
        for gpu_id in range(gpu_num):
            model = instantiate_from_config(model_config)
            # model = model.cuda(gpu_id)
            assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
            model = load_model_checkpoint(model, ckpt_path)
            model.eval()
            model_list.append(model)
        self.model_list = model_list
        self.save_fps = 8

    def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None):
        seed_everything(seed)
        transform = transforms.Compose([
            transforms.Resize(min(self.resolution)),
            transforms.CenterCrop(self.resolution),
        ])
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()
        gpu_id = 0
        if steps > 60:
            steps = 60
        model = self.model_list[gpu_id]
        model = model.cuda()
        batch_size = 1
        channels = model.model.diffusion_model.out_channels
        frames = model.temporal_length
        h, w = self.resolution[0] // 8, self.resolution[1] // 8
        noise_shape = [batch_size, channels, frames, h, w]

        # text cond
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_emb = model.get_learned_conditioning([prompt])

            # img cond
            img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
            img_tensor = (img_tensor / 255. - 0.5) * 2
            image_tensor_resized = transform(img_tensor)  # 3,h,w
            videos = image_tensor_resized.unsqueeze(0)  # b,c,h,w
            z = get_latent_z(model, videos.unsqueeze(2))  # b,c,1,h,w

            if image2 is not None:
                img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device)
                img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
                image_tensor_resized2 = transform(img_tensor2)  # 3,h,w
                videos2 = image_tensor_resized2.unsqueeze(0)  # b,c,h,w
                z2 = get_latent_z(model, videos2.unsqueeze(2))  # b,c,1,h,w

            # Concat-conditioning: only the first and last latent frames carry
            # image information; the frames in between are zeroed out.
            img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
            img_tensor_repeat = torch.zeros_like(img_tensor_repeat)
            img_tensor_repeat[:, :, :1, :, :] = z
            if image2 is not None:
                # interpolation: end on the second image
                img_tensor_repeat[:, :, -1:, :, :] = z2
            else:
                # looping: end on the start image
                img_tensor_repeat[:, :, -1:, :, :] = z

            cond_images = model.embedder(img_tensor.unsqueeze(0))  # b,l,c
            img_emb = model.image_proj_model(cond_images)
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            fs = torch.tensor([fs], dtype=torch.long, device=model.device)
            cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

            ## inference
            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
            ## remove the last frame (in loop mode it repeats the first frame)
            if image2 is None:
                batch_samples = batch_samples[:, :, :, :-1, ...]
        ## b,samples,c,t,h,w
        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
        prompt_str = prompt_str[:40]
        if len(prompt_str) == 0:
            prompt_str = 'empty_prompt'

        save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
        model = model.cpu()
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    def download_model(self):
        REPO_ID = 'Doubiiu/DynamiCrafter_' + str(self.resolution[1]) + '_Interp'
        filename_list = ['model.ckpt']
        if not os.path.exists('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_interp_v1/'):
            os.makedirs('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_interp_v1/')
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_interp_v1/', filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename,
                                local_dir='./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_interp_v1/',
                                local_dir_use_symlinks=False)


if __name__ == '__main__':
    i2v = Image2Video()
    video_path = i2v.get_image('prompts/art.png', 'man fishing in a boat at sunset')
    print('done', video_path)
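A minimal usage sketch for the loop and interpolation paths. The image-loading code, the second frame's path, and the chosen resolution are illustrative assumptions, not part of the commit; `get_image` expects RGB arrays of shape (H, W, 3) with values in [0, 255]:

import numpy as np
from PIL import Image

# Hypothetical input frames (paths are placeholders).
frame_a = np.array(Image.open('prompts/art.png').convert('RGB'))      # (H, W, 3) uint8
frame_b = np.array(Image.open('prompts/art_end.png').convert('RGB'))  # hypothetical end frame

i2v = Image2Video(resolution='320_512')  # the 512 interp model runs at 320x512

# Loop: image2 omitted, so the clip is conditioned to start and end on frame_a.
loop_mp4 = i2v.get_image(frame_a, 'a beach with waves and clouds at sunset')

# Interpolation: the clip starts on frame_a and ends on frame_b.
interp_mp4 = i2v.get_image(frame_a, 'a beach with waves and clouds at sunset', image2=frame_b)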
The new driver script takes the application as its first argument, `interp` or `loop`; both modes share the 512 interpolation checkpoint:

version=$1 # interp or loop
ckpt=checkpoints/dynamicrafter_512_interp_v1/model.ckpt
config=configs/inference_512_v1.0.yaml
prompt_dir=prompts/512_$1/
res_dir="results"
FS=5 ## This model adopts FPS=5, range recommended: 5-30 (smaller value -> larger motion)
if [ "$1" == "interp" ]; then
seed=12306
name=dynamicrafter_512_$1_seed${seed}
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS} \
--timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --interp
else
seed=234
name=dynamicrafter_512_$1_seed${seed}
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS} \
--timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --loop
fi
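Assuming the script lives in the repo's scripts directory under a name like `run_application.sh` (the path is not shown in this commit view), it would be invoked as, for example, `sh scripts/run_application.sh interp` or `sh scripts/run_application.sh loop`.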
The commit also scales the multi-GPU inference launcher down from eight GPUs to two:

@@ -63,8 +63,8 @@ fi
 ## inference using single node with multi-GPUs:
 if [ "$1" == "256" ]; then
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
---nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
+CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch \
+--nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
 scripts/evaluation/ddp_wrapper.py \
 --module 'inference' \
 --seed ${seed} \
@@ -81,8 +81,8 @@ scripts/evaluation/ddp_wrapper.py \
 --video_length 16 \
 --frame_stride ${FS}
 else
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
---nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
+CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch \
+--nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
 scripts/evaluation/ddp_wrapper.py \
 --module 'inference' \
 --seed ${seed} \
...
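A practical note on this change: `--nproc_per_node` must match the number of devices listed in `CUDA_VISIBLE_DEVICES` (here both were reduced from eight to two), otherwise some launched ranks will have no GPU to bind to or will oversubscribe the visible ones.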