Commit d34de4c4 authored by mashun1

add_loop

parent 4352b6e6
a beach with waves and clouds at sunset
\ No newline at end of file

@@ -61,7 +61,7 @@ def load_prompts(prompt_file):
    f.close()
    return prompt_list

-def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, gfi=False):
+def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, interp=False):
    transform = transforms.Compose([
        transforms.Resize(min(video_size)),
        transforms.CenterCrop(video_size),

@@ -85,12 +85,22 @@ def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, gfi=False
    prompt_list = load_prompts(prompt_file[default_idx])
    n_samples = len(prompt_list)
    for idx in range(n_samples):
-        image = Image.open(file_list[idx]).convert('RGB')
-        image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
-        frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
+        if interp:
+            image1 = Image.open(file_list[2*idx]).convert('RGB')
+            image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
+            image2 = Image.open(file_list[2*idx+1]).convert('RGB')
+            image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
+            frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
+            frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
+            frame_tensor = torch.cat([frame_tensor1, frame_tensor2], dim=1)
+            _, filename = os.path.split(file_list[idx*2])
+        else:
+            image = Image.open(file_list[idx]).convert('RGB')
+            image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
+            frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
+            _, filename = os.path.split(file_list[idx])
        data_list.append(frame_tensor)
-        _, filename = os.path.split(file_list[idx])
        filename_list.append(filename)
    return filename_list, data_list, prompt_list

@@ -153,7 +163,7 @@ def get_latent_z(model, videos):
def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
-                        unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, gfi=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
+                        unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
    ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
    batch_size = noise_shape[0]
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

@@ -169,7 +179,7 @@ def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddi
    cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
    if model.model.conditioning_key == 'hybrid':
        z = get_latent_z(model, videos) # b c t h w
-        if loop or gfi:
+        if loop or interp:
            img_cat_cond = torch.zeros_like(z)
            img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
            img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]

@@ -271,7 +281,7 @@ def run_inference(args, gpu_num, gpu_no):
    ## prompt file setting
    assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
-    filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, gfi=args.gfi)
+    filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, interp=args.interp)
    num_samples = len(prompt_list)
    samples_split = num_samples // gpu_num
    print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples))

@@ -293,7 +303,7 @@ def run_inference(args, gpu_num, gpu_no):
    videos = videos.unsqueeze(0).to("cuda")
    batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
-                        args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.gfi, args.timestep_spacing, args.guidance_rescale)
+                        args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.interp, args.timestep_spacing, args.guidance_rescale)
    ## save each example individually
    for nn, samples in enumerate(batch_samples):

@@ -332,7 +342,7 @@ def get_parser():
    ## currently not support looping video and generative frame interpolation
    parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not")
-    parser.add_argument("--gfi", action='store_true', default=False, help="generate generative frame interpolation (gfi) or not")
+    parser.add_argument("--interp", action='store_true', default=False, help="generate generative frame interpolation or not")
    return parser
...
import os
import time
from omegaconf import OmegaConf
import torch

@@ -11,7 +11,7 @@ from pytorch_lightning import seed_everything
class Image2Video():
-    def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256', **kwargs) -> None:
+    def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') -> None:
        self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw
        self.download_model()

@@ -34,8 +34,6 @@ class Image2Video():
        self.model_list = model_list
        self.save_fps = 8
-        self.kwargs = kwargs

    def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
        seed_everything(seed)
        transform = transforms.Compose([

@@ -57,6 +55,7 @@ class Image2Video():
        noise_shape = [batch_size, channels, frames, h, w]

        # text cond
+        with torch.no_grad(), torch.cuda.amp.autocast():
        text_emb = model.get_learned_conditioning([prompt])
        # img cond

@@ -79,7 +78,6 @@ class Image2Video():
        cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

        ## inference
-        with torch.no_grad(), torch.cuda.amp.autocast():
        batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
        ## b,samples,c,t,h,w
        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
...
import os
import time
from omegaconf import OmegaConf
import torch
from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
from utils.utils import instantiate_from_config
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything
class Image2Video():
    def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') -> None:
        self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw
        self.download_model()

        self.result_dir = result_dir
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)
        ckpt_path='checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_interp_v1/model.ckpt'
        config_file='configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml'
        config = OmegaConf.load(config_file)
        model_config = config.pop("model", OmegaConf.create())
        model_config['params']['unet_config']['params']['use_checkpoint']=False
        model_list = []
        for gpu_id in range(gpu_num):
            model = instantiate_from_config(model_config)
            # model = model.cuda(gpu_id)
            assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
            model = load_model_checkpoint(model, ckpt_path)
            model.eval()
            model_list.append(model)
        self.model_list = model_list
        self.save_fps = 8
    def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None):
        seed_everything(seed)
        transform = transforms.Compose([
            transforms.Resize(min(self.resolution)),
            transforms.CenterCrop(self.resolution),
            ])
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
        start = time.time()
        gpu_id=0
        if steps > 60:
            steps = 60
        model = self.model_list[gpu_id]
        model = model.cuda()
        batch_size=1
        channels = model.model.diffusion_model.out_channels
        frames = model.temporal_length
        h, w = self.resolution[0] // 8, self.resolution[1] // 8
        noise_shape = [batch_size, channels, frames, h, w]

        # text cond
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_emb = model.get_learned_conditioning([prompt])

            # img cond
            img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
            img_tensor = (img_tensor / 255. - 0.5) * 2
            image_tensor_resized = transform(img_tensor) #3,h,w
            videos = image_tensor_resized.unsqueeze(0) # bchw
            z = get_latent_z(model, videos.unsqueeze(2)) # b,c,1,h,w
            if image2 is not None:
                # optional second keyframe: encoded the same way and used as the end-frame condition
                img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device)
                img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
                image_tensor_resized2 = transform(img_tensor2) #3,h,w
                videos2 = image_tensor_resized2.unsqueeze(0) # bchw
                z2 = get_latent_z(model, videos2.unsqueeze(2)) # b,c,1,h,w

            # concat-conditioning: start from all-zero per-frame latents, then fill only the endpoint frames
            img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
            img_tensor_repeat = torch.zeros_like(img_tensor_repeat)
            img_tensor_repeat[:,:,:1,:,:] = z
            if image2 is not None:
                img_tensor_repeat[:,:,-1:,:,:] = z2
            else:
                img_tensor_repeat[:,:,-1:,:,:] = z

            cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
            img_emb = model.image_proj_model(cond_images)
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            fs = torch.tensor([fs], dtype=torch.long, device=model.device)
            cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

            ## inference
            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
            ## remove the last frame
            if image2 is None:
                batch_samples = batch_samples[:,:,:,:-1,...]

        ## b,samples,c,t,h,w
        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
        prompt_str=prompt_str[:40]
        if len(prompt_str) == 0:
            prompt_str = 'empty_prompt'
        save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
        model = model.cpu()
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")
    def download_model(self):
        REPO_ID = 'Doubiiu/DynamiCrafter_'+str(self.resolution[1])+'_Interp'
        filename_list = ['model.ckpt']
        if not os.path.exists('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/'):
            os.makedirs('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/')
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/', filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/', local_dir_use_symlinks=False)

if __name__ == '__main__':
    i2v = Image2Video()
    # note: get_image converts `image` with torch.from_numpy, so it expects an HWC numpy array rather than a file path
    video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset')
    print('done', video_path)
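
For reference, a minimal usage sketch of the two-image interpolation path added above (not part of this commit): the frame file names are placeholders, the resolution string mirrors the 320x512 setting and fs=5 mirrors the FS value in the shell script below, and images are passed as HWC numpy arrays to match the torch.from_numpy conversion in get_image.

import numpy as np
from PIL import Image

i2v = Image2Video(resolution='320_512')
# two placeholder keyframes to interpolate between
first = np.array(Image.open('frame_start.png').convert('RGB'))
last = np.array(Image.open('frame_end.png').convert('RGB'))
video_path = i2v.get_image(first, 'a beach with waves and clouds at sunset', fs=5, image2=last)
print(video_path)
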
version=$1 # interp or loop
ckpt=checkpoints/dynamicrafter_512_interp_v1/model.ckpt
config=configs/inference_512_v1.0.yaml
prompt_dir=prompts/512_$1/
res_dir="results"
FS=5 ## This model adopts FPS=5, range recommended: 5-30 (smaller value -> larger motion)
if [ "$1" == "interp" ]; then
seed=12306
name=dynamicrafter_512_$1_seed${seed}
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS} \
--timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --interp
else
seed=234
name=dynamicrafter_512_$1_seed${seed}
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS} \
--timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --loop
fi

@@ -63,8 +63,8 @@ fi
## inference using single node with multi-GPUs:
if [ "$1" == "256" ]; then
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
---nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
+CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch \
+--nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
scripts/evaluation/ddp_wrapper.py \
--module 'inference' \
--seed ${seed} \

@@ -81,8 +81,8 @@ scripts/evaluation/ddp_wrapper.py \
--video_length 16 \
--frame_stride ${FS}
else
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
---nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
+CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch \
+--nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
scripts/evaluation/ddp_wrapper.py \
--module 'inference' \
--seed ${seed} \
...