Commit d34de4c4 authored by mashun1

add_loop

parent 4352b6e6
a beach with waves and clouds at sunset
import argparse, os, sys, glob
import datetime, time
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict

import torch
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything
from PIL import Image
sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
from lvdm.models.samplers.ddim import DDIMSampler
from lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
from utils.utils import instantiate_from_config


def get_filelist(data_dir, postfixes):
    patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes]
    file_list = []
    for pattern in patterns:
        file_list.extend(glob.glob(pattern))
    file_list.sort()
    return file_list


def load_model_checkpoint(model, ckpt):
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in list(state_dict.keys()):
        state_dict = state_dict["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except:
            ## rename the keys for 256x256 model
            new_pl_sd = OrderedDict()
            for k, v in state_dict.items():
                new_pl_sd[k] = v

            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd[k]
                    del new_pl_sd[k]
            model.load_state_dict(new_pl_sd, strict=True)
    else:
        # deepspeed
        new_pl_sd = OrderedDict()
        for key in state_dict['module'].keys():
            new_pl_sd[key[16:]] = state_dict['module'][key]
        model.load_state_dict(new_pl_sd)
    print('>>> model checkpoint loaded.')
    return model


def load_prompts(prompt_file):
    f = open(prompt_file, 'r')
    prompt_list = []
    for idx, line in enumerate(f.readlines()):
        l = line.strip()
        if len(l) != 0:
            prompt_list.append(l)
    f.close()
    return prompt_list


def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, interp=False):
    transform = transforms.Compose([
        transforms.Resize(min(video_size)),
        transforms.CenterCrop(video_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    ## load prompts
    prompt_file = get_filelist(data_dir, ['txt'])
    assert len(prompt_file) > 0, "Error: found NO prompt file!"
    ###### default prompt
    default_idx = 0
    default_idx = min(default_idx, len(prompt_file)-1)
    if len(prompt_file) > 1:
        print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.")
    ## only use the first one (sorted by name) if multiple exist

    ## load video
    file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
    # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!"
    data_list = []
    filename_list = []
    prompt_list = load_prompts(prompt_file[default_idx])
    n_samples = len(prompt_list)
    for idx in range(n_samples):
        if interp:
            image1 = Image.open(file_list[2*idx]).convert('RGB')
            image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
            image2 = Image.open(file_list[2*idx+1]).convert('RGB')
            image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
            frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
            frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
            frame_tensor = torch.cat([frame_tensor1, frame_tensor2], dim=1)
            _, filename = os.path.split(file_list[idx*2])
        else:
            image = Image.open(file_list[idx]).convert('RGB')
            image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
            frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
            _, filename = os.path.split(file_list[idx])

        data_list.append(frame_tensor)
        filename_list.append(filename)

    return filename_list, data_list, prompt_list


def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
    filename = filename.split('.')[0]+'.mp4'
    prompt = prompt[0] if isinstance(prompt, list) else prompt

    ## save video
    videos = [samples]
    savedirs = [fakedir]
    for idx, video in enumerate(videos):
        if video is None:
            continue
        # b,c,t,h,w
        video = video.detach().cpu()
        video = torch.clamp(video.float(), -1., 1.)
        n = video.shape[0]
        video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
        if loop:
            video = video[:-1,...]

        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, 1*h, n*w]
        grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w]
        grid = (grid + 1.0) / 2.0
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        path = os.path.join(savedirs[idx], filename)
        torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) ## crf indicates the quality


def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False):
    prompt = prompt[0] if isinstance(prompt, list) else prompt

    ## save video
    videos = [samples]
    savedirs = [fakedir]
    for idx, video in enumerate(videos):
        if video is None:
            continue
        # b,c,t,h,w
        video = video.detach().cpu()
        if loop: # remove the last frame
            video = video[:,:,:-1,...]
        video = torch.clamp(video.float(), -1., 1.)
        n = video.shape[0]
        for i in range(n):
            grid = video[i,...]
            grid = (grid + 1.0) / 2.0
            grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) #thwc
            path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), f'{filename.split(".")[0]}_sample{i}.mp4')
            torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'})


def get_latent_z(model, videos):
    b, c, t, h, w = videos.shape
    x = rearrange(videos, 'b c t h w -> (b t) c h w')
    z = model.encode_first_stage(x)
    z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
    return z


def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
                        unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
    ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
    batch_size = noise_shape[0]
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

    if not text_input:
        prompts = [""]*batch_size

    img = videos[:,:,0] #bchw
    img_emb = model.embedder(img) ## blc
    img_emb = model.image_proj_model(img_emb)

    cond_emb = model.get_learned_conditioning(prompts)
    cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
    if model.model.conditioning_key == 'hybrid':
        z = get_latent_z(model, videos) # b c t h w
        if loop or interp:
            img_cat_cond = torch.zeros_like(z)
            img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
            img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
        else:
            img_cat_cond = z[:,:,:1,:,:]
            img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
        cond["c_concat"] = [img_cat_cond] # b c 1 h w

    if unconditional_guidance_scale != 1.0:
        if model.uncond_type == "empty_seq":
            prompts = batch_size * [""]
            uc_emb = model.get_learned_conditioning(prompts)
        elif model.uncond_type == "zero_embed":
            uc_emb = torch.zeros_like(cond_emb)
        uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
        uc_img_emb = model.image_proj_model(uc_img_emb)
        uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
        if model.model.conditioning_key == 'hybrid':
            uc["c_concat"] = [img_cat_cond]
    else:
        uc = None

    ## we need one more unconditioning: image=yes, text=""
    if multiple_cond_cfg and cfg_img != 1.0:
        uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
        if model.model.conditioning_key == 'hybrid':
            uc_2["c_concat"] = [img_cat_cond]
        kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
    else:
        kwargs.update({"unconditional_conditioning_img_nonetext": None})

    z0 = None
    cond_mask = None

    batch_variants = []
    for _ in range(n_samples):
        if z0 is not None:
            cond_z0 = z0.clone()
            kwargs.update({"clean_cond": True})
        else:
            cond_z0 = None
        if ddim_sampler is not None:
            samples, _ = ddim_sampler.sample(S=ddim_steps,
                                            conditioning=cond,
                                            batch_size=batch_size,
                                            shape=noise_shape[1:],
                                            verbose=False,
                                            unconditional_guidance_scale=unconditional_guidance_scale,
                                            unconditional_conditioning=uc,
                                            eta=ddim_eta,
                                            cfg_img=cfg_img,
                                            mask=cond_mask,
                                            x0=cond_z0,
                                            fs=fs,
                                            timestep_spacing=timestep_spacing,
                                            guidance_rescale=guidance_rescale,
                                            **kwargs
                                            )

        ## reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples)
        batch_variants.append(batch_images)
    ## variants, batch, c, t, h, w
    batch_variants = torch.stack(batch_variants)
    return batch_variants.permute(1, 0, 2, 3, 4, 5)


def run_inference(args, gpu_num, gpu_no):
    ## model config
    config = OmegaConf.load(args.config)
    model_config = config.pop("model", OmegaConf.create())

    ## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
    model_config['params']['unet_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(model_config)
    model = model.cuda(gpu_no)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()

    ## run over data
    assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'Inference with {n_frames} frames')
    noise_shape = [args.bs, channels, n_frames, h, w]

    fakedir = os.path.join(args.savedir, "samples")
    fakedir_separate = os.path.join(args.savedir, "samples_separate")

    # os.makedirs(fakedir, exist_ok=True)
    os.makedirs(fakedir_separate, exist_ok=True)

    ## prompt file setting
    assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
    filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, interp=args.interp)
    num_samples = len(prompt_list)
    samples_split = num_samples // gpu_num
    print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples))
    #indices = random.choices(list(range(0, num_samples)), k=samples_per_device)
    indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
    prompt_list_rank = [prompt_list[i] for i in indices]
    data_list_rank = [data_list[i] for i in indices]
    filename_list_rank = [filename_list[i] for i in indices]

    start = time.time()
    with torch.no_grad(), torch.cuda.amp.autocast():
        for idx, indice in tqdm(enumerate(range(0, len(prompt_list_rank), args.bs)), desc='Sample Batch'):
            prompts = prompt_list_rank[indice:indice+args.bs]
            videos = data_list_rank[indice:indice+args.bs]
            filenames = filename_list_rank[indice:indice+args.bs]
            if isinstance(videos, list):
                videos = torch.stack(videos, dim=0).to("cuda")
            else:
                videos = videos.unsqueeze(0).to("cuda")

            batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
                                args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.interp, args.timestep_spacing, args.guidance_rescale)

            ## save each example individually
            for nn, samples in enumerate(batch_samples):
                ## samples : [n_samples,c,t,h,w]
                prompt = prompts[nn]
                filename = filenames[nn]
                # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
                save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop)

    print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir", type=str, default=None, help="results saving path")
    parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts")
    parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
    parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
    parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
    parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one")
    parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
    parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
    parser.add_argument("--frame_stride", type=int, default=3, help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)")
    parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
    parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything")
    parser.add_argument("--video_length", type=int, default=16, help="inference video length")
    parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt")
    parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not")
    parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not")
    parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning")
    parser.add_argument("--timestep_spacing", type=str, default="uniform", help="The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.")
    parser.add_argument("--guidance_rescale", type=float, default=0.0, help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)")
    parser.add_argument("--perframe_ae", action='store_true', default=False, help="if we use per-frame AE decoding, set it to True to save GPU memory, especially for the model of 576x1024")

    ## looping video and generative frame interpolation
    parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not")
    parser.add_argument("--interp", action='store_true', default=False, help="generate generative frame interpolation or not")
    return parser


if __name__ == '__main__':
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    print("@DynamiCrafter cond-Inference: %s"%now)
    parser = get_parser()
    args = parser.parse_args()

    seed_everything(args.seed)
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)
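
A note on the data layout implied above: with --interp, load_data_prompts pairs the sorted images in --prompt_dir two per prompt, using files 2*i and 2*i+1 as the start and end frames for prompt i. A minimal sanity-check sketch reusing the helpers defined above (the prompts/512_interp/ directory name follows the run script further down and is only an example):

prompt_dir = 'prompts/512_interp/'
prompt_file = get_filelist(prompt_dir, ['txt'])[0]          # first .txt file, sorted by name
prompts = load_prompts(prompt_file)
images = get_filelist(prompt_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
assert len(images) >= 2 * len(prompts), "interp mode expects two images per prompt"
for i, p in enumerate(prompts):
    print(f'{p!r}: start={images[2*i]}, end={images[2*i+1]}')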
import os
import time
from omegaconf import OmegaConf
import torch
from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
from utils.utils import instantiate_from_config
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything


class Image2Video():
    def __init__(self, result_dir='./tmp/', gpu_num=1, resolution='256_256') -> None:
        self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw
        self.download_model()

        self.result_dir = result_dir
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)
        ckpt_path = 'checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_v1/model.ckpt'
        config_file = 'configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml'
        config = OmegaConf.load(config_file)
        model_config = config.pop("model", OmegaConf.create())
        model_config['params']['unet_config']['params']['use_checkpoint'] = False
        model_list = []
        for gpu_id in range(gpu_num):
            model = instantiate_from_config(model_config)
            # model = model.cuda(gpu_id)
            assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
            model = load_model_checkpoint(model, ckpt_path)
            model.eval()
            model_list.append(model)
        self.model_list = model_list
        self.save_fps = 8

    def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
        seed_everything(seed)
        transform = transforms.Compose([
            transforms.Resize(min(self.resolution)),
            transforms.CenterCrop(self.resolution),
            ])
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()
        gpu_id = 0
        if steps > 60:
            steps = 60
        model = self.model_list[gpu_id]
        model = model.cuda()
        batch_size = 1
        channels = model.model.diffusion_model.out_channels
        frames = model.temporal_length
        h, w = self.resolution[0] // 8, self.resolution[1] // 8
        noise_shape = [batch_size, channels, frames, h, w]

        # text cond
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_emb = model.get_learned_conditioning([prompt])

            # img cond
            img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
            img_tensor = (img_tensor / 255. - 0.5) * 2

            image_tensor_resized = transform(img_tensor) #3,h,w
            videos = image_tensor_resized.unsqueeze(0) # bchw

            z = get_latent_z(model, videos.unsqueeze(2)) #b,c,1,h,w

            img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)

            cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
            img_emb = model.image_proj_model(cond_images)

            imtext_cond = torch.cat([text_emb, img_emb], dim=1)

            fs = torch.tensor([fs], dtype=torch.long, device=model.device)
            cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

            ## inference
            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
            ## b,samples,c,t,h,w
            prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
            prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
            prompt_str = prompt_str[:40]
            if len(prompt_str) == 0:
                prompt_str = 'empty_prompt'

        save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
        model = model.cpu()
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    def download_model(self):
        REPO_ID = 'Doubiiu/DynamiCrafter_'+str(self.resolution[1]) if self.resolution[1] != 256 else 'Doubiiu/DynamiCrafter'
        filename_list = ['model.ckpt']
        if not os.path.exists('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/'):
            os.makedirs('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/')
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', local_dir_use_symlinks=False)


if __name__ == '__main__':
    i2v = Image2Video()
    video_path = i2v.get_image('prompts/art.png', 'man fishing in a boat at sunset')
    print('done', video_path)
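
A hedged usage sketch for the Image2Video class above (assuming it is run from the repository root with the checkpoint and config in place): get_image converts its image argument with torch.from_numpy and rescales by 255, so it expects a decoded RGB array rather than a file path like the __main__ block passes.

import numpy as np
from PIL import Image

img = np.array(Image.open('prompts/art.png').convert('RGB'))   # HxWx3 uint8 array
i2v = Image2Video(result_dir='./tmp/', resolution='256_256')
video_path = i2v.get_image(img, 'man fishing in a boat at sunset', steps=50, cfg_scale=7.5, fs=3, seed=123)
print('done', video_path)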
import os
import time
from omegaconf import OmegaConf
import torch
from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
from utils.utils import instantiate_from_config
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything


class Image2Video():
    def __init__(self, result_dir='./tmp/', gpu_num=1, resolution='256_256') -> None:
        self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw
        self.download_model()

        self.result_dir = result_dir
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)
        ckpt_path = 'checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_interp_v1/model.ckpt'
        config_file = 'configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml'
        config = OmegaConf.load(config_file)
        model_config = config.pop("model", OmegaConf.create())
        model_config['params']['unet_config']['params']['use_checkpoint'] = False
        model_list = []
        for gpu_id in range(gpu_num):
            model = instantiate_from_config(model_config)
            # model = model.cuda(gpu_id)
            assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
            model = load_model_checkpoint(model, ckpt_path)
            model.eval()
            model_list.append(model)
        self.model_list = model_list
        self.save_fps = 8

    def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None):
        seed_everything(seed)
        transform = transforms.Compose([
            transforms.Resize(min(self.resolution)),
            transforms.CenterCrop(self.resolution),
            ])
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()
        gpu_id = 0
        if steps > 60:
            steps = 60
        model = self.model_list[gpu_id]
        model = model.cuda()
        batch_size = 1
        channels = model.model.diffusion_model.out_channels
        frames = model.temporal_length
        h, w = self.resolution[0] // 8, self.resolution[1] // 8
        noise_shape = [batch_size, channels, frames, h, w]

        # text cond
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_emb = model.get_learned_conditioning([prompt])

            # img cond
            img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
            img_tensor = (img_tensor / 255. - 0.5) * 2
            image_tensor_resized = transform(img_tensor) #3,h,w
            videos = image_tensor_resized.unsqueeze(0) # bchw
            z = get_latent_z(model, videos.unsqueeze(2)) #b,c,1,h,w

            if image2 is not None:
                img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device)
                img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
                image_tensor_resized2 = transform(img_tensor2) #3,h,w
                videos2 = image_tensor_resized2.unsqueeze(0) # bchw
                z2 = get_latent_z(model, videos2.unsqueeze(2)) #b,c,1,h,w

            img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
            img_tensor_repeat = torch.zeros_like(img_tensor_repeat)
            ## only the first and last latent frames carry the image condition
            img_tensor_repeat[:,:,:1,:,:] = z
            if image2 is not None:
                img_tensor_repeat[:,:,-1:,:,:] = z2
            else:
                img_tensor_repeat[:,:,-1:,:,:] = z

            cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
            img_emb = model.image_proj_model(cond_images)
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)

            fs = torch.tensor([fs], dtype=torch.long, device=model.device)
            cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

            ## inference
            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
            ## remove the last frame
            if image2 is None:
                batch_samples = batch_samples[:,:,:,:-1,...]
        ## b,samples,c,t,h,w
        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
        prompt_str = prompt_str[:40]
        if len(prompt_str) == 0:
            prompt_str = 'empty_prompt'

        save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
        model = model.cpu()
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    def download_model(self):
        REPO_ID = 'Doubiiu/DynamiCrafter_'+str(self.resolution[1])+'_Interp'
        filename_list = ['model.ckpt']
        if not os.path.exists('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/'):
            os.makedirs('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/')
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/', filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/', local_dir_use_symlinks=False)


if __name__ == '__main__':
    i2v = Image2Video()
    video_path = i2v.get_image('prompts/art.png', 'man fishing in a boat at sunset')
    print('done', video_path)
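
In the variant above, the concat condition is zeroed except for the first and last latent frames: passing image2 gives generative frame interpolation between the two inputs, while a single image conditions both ends and, after the duplicated last frame is dropped, yields a looping video. A minimal sketch, assuming two local RGB images; 'end.png' is a made-up filename, and resolution='320_512' is chosen to match the dynamicrafter_512_interp_v1 checkpoint used by the run script below.

import numpy as np
from PIL import Image

first = np.array(Image.open('prompts/art.png').convert('RGB'))
last = np.array(Image.open('prompts/end.png').convert('RGB'))   # hypothetical second endpoint
i2v = Image2Video(resolution='320_512')
video_path = i2v.get_image(first, 'man fishing in a boat at sunset', image2=last)   # omit image2 for a looping video
print('done', video_path)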
version=$1 # interp or loop
ckpt=checkpoints/dynamicrafter_512_interp_v1/model.ckpt
config=configs/inference_512_v1.0.yaml
prompt_dir=prompts/512_$1/
res_dir="results"
FS=5 ## This model adopts FPS=5, range recommended: 5-30 (smaller value -> larger motion)
if [ "$1" == "interp" ]; then
seed=12306
name=dynamicrafter_512_$1_seed${seed}
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS} \
--timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --interp
else
seed=234
name=dynamicrafter_512_$1_seed${seed}
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS} \
--timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --loop
fi
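
The script above takes a single positional argument, interp or loop. Both branches run scripts/evaluation/inference.py at 320x512 with the dynamicrafter_512_interp_v1 checkpoint, uniform_trailing timestep spacing, and guidance rescale 0.7; the argument only switches between --interp (paired start/end images read from prompts/512_interp/) and --loop (single conditioning images read from prompts/512_loop/) and selects a different seed.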
@@ -63,8 +63,8 @@ fi
 ## inference using single node with multi-GPUs:
 if [ "$1" == "256" ]; then
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
---nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
+CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch \
+--nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
 scripts/evaluation/ddp_wrapper.py \
 --module 'inference' \
 --seed ${seed} \
@@ -81,8 +81,8 @@ scripts/evaluation/ddp_wrapper.py \
 --video_length 16 \
 --frame_stride ${FS}
 else
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
---nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
+CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch \
+--nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
 scripts/evaluation/ddp_wrapper.py \
 --module 'inference' \
 --seed ${seed} \