Commit 1ad55bb4 authored by mashun1

i2vgen-xl
import os
import json
import torch
import logging
import collections

from utils.registry_class import PRETRAIN


@PRETRAIN.register_function()
def pretrain_specific_strategies(
        model,
        resume_checkpoint,
        sd_keys_path=None,
        grad_scale=1,
        fix_weight=False,
        **kwargs):
    state_dict = torch.load(resume_checkpoint, map_location='cpu')
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    # [1] load the pretrained weights; fall back to a key-by-key copy that skips
    # missing or shape-mismatched entries
    try:
        ret = model.load_state_dict(state_dict, strict=False)
        logging.info(f'load a fixed model with {ret}')
    except Exception:
        model_dict = model.state_dict()
        key_list = list(state_dict.keys())
        for skey, item in state_dict.items():
            if skey not in model_dict:
                logging.info(f'Skip {skey}')
                continue
            if item.shape != model_dict[skey].shape:
                logging.info(f'Skip {skey} with different shape {item.shape} {model_dict[skey].shape}')
                continue
            model_dict[skey].copy_(item)
        model.load_state_dict(model_dict)

    # [2] assign strategies: freeze the listed parameters or rescale their gradients
    total_size = 0
    state_dict = {} if sd_keys_path is None else json.load(open(sd_keys_path))
    for k, p in model.named_parameters():
        if k in state_dict:
            total_size += p.numel()
            if fix_weight:
                p.requires_grad = False
            else:
                p.register_hook(lambda grad: grad_scale * grad)

    resume_step = int(os.path.basename(resume_checkpoint).split('_')[-1].split('.')[0])
    logging.info(f'Successfully loaded the step-{resume_step} model from {resume_checkpoint}')
    logging.info(f'load a fixed model with {int(total_size / (1024 ** 2))}M parameters')
    return model, resume_step


@PRETRAIN.register_function()
def pretrain_from_sd():
    pass


@PRETRAIN.register_function()
def pretrain_ema_model():
    pass
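# --- Illustrative usage (hedged sketch, not part of this commit) ---
# How the registered function might be called from a training entry point.
# The checkpoint path, sd_keys JSON file and grad_scale value are hypothetical
# placeholders; resume_step is parsed from the trailing "_<step>." suffix of
# the checkpoint file name.
#
#   model, resume_step = pretrain_specific_strategies(
#       model,
#       resume_checkpoint='workspace/checkpoints/non_ema_00228000.pth',
#       sd_keys_path='data/stable_diffusion_keys.json',   # keys whose weights are pretrained
#       grad_scale=0.2,      # shrink gradients of the listed parameters
#       fix_weight=False)    # True would freeze them instead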
from .image_dataset import *
from .video_dataset import *
import os
import cv2
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset

from utils.registry_class import DATASETS


@DATASETS.register_class()
class ImageDataset(Dataset):
    def __init__(self,
                 data_list,
                 data_dir_list,
                 max_words=1000,
                 vit_resolution=[224, 224],
                 resolution=(384, 256),
                 max_frames=1,
                 transforms=None,
                 vit_transforms=None,
                 **kwargs):
        self.max_frames = max_frames
        self.resolution = resolution
        self.transforms = transforms
        self.vit_resolution = vit_resolution
        self.vit_transforms = vit_transforms

        # Each list file contains lines of the form "<image_key>|||<caption>";
        # every entry is paired with the data_dir it was read from.
        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            with open(item_path, 'r') as f:
                lines = f.readlines()
            lines = [[data_dir, item.strip()] for item in lines]
            image_list.extend(lines)
        self.image_list = image_list

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        data_dir, file_path = self.image_list[index]
        img_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(img_key, e))
            caption = ''
            img_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption, img_key

    def _get_image_data(self, data_dir, file_path):
        frame_list = []
        img_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, img_key)
        # Retry the read a few times before giving up.
        for _ in range(5):
            try:
                image = Image.open(file_path)
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                frame_list.append(image)
                break
            except Exception as e:
                logging.info('{} read image failed with error: {}'.format(img_key, e))
                continue

        video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        try:
            if len(frame_list) > 0:
                mid_frame = frame_list[0]
                vit_frame = self.vit_transforms(mid_frame)
                frame_tensor = self.transforms(frame_list)
                video_data[:len(frame_list), ...] = frame_tensor
            else:
                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        except Exception:
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        ref_frame = copy(video_data[0])
        return ref_frame, vit_frame, video_data, caption
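# --- Illustrative usage (hedged sketch, not part of this commit) ---
# How ImageDataset might be wired into a DataLoader. The file names, transform
# objects and batch size below are placeholders, not values from this
# repository; `transforms` is assumed to be a project-specific callable that
# accepts a *list* of PIL images, while `vit_transforms` takes a single image.
#
#   from torch.utils.data import DataLoader
#
#   dataset = ImageDataset(
#       data_list=['data/img_list.txt'],      # lines: "<image_key>|||<caption>"
#       data_dir_list=['data/images'],
#       resolution=(448, 256),
#       vit_resolution=(224, 224),
#       transforms=frame_transforms,          # list of PIL images -> [F, 3, H, W] tensor
#       vit_transforms=vit_transforms)        # single PIL image   -> [3, 224, 224] tensor
#   loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
#   ref_frame, vit_frame, video_data, caption, img_key = next(iter(loader))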
import os
import cv2
import json
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from torch.utils.data import Dataset

from utils.registry_class import DATASETS


@DATASETS.register_class()
class VideoDataset(Dataset):
    def __init__(self,
                 data_list,
                 data_dir_list,
                 max_words=1000,
                 resolution=(384, 256),
                 vit_resolution=(224, 224),
                 max_frames=16,
                 sample_fps=8,
                 transforms=None,
                 vit_transforms=None,
                 get_first_frame=False,
                 **kwargs):
        self.max_words = max_words
        self.max_frames = max_frames
        self.resolution = resolution
        self.vit_resolution = vit_resolution
        self.sample_fps = sample_fps
        self.transforms = transforms
        self.vit_transforms = vit_transforms
        self.get_first_frame = get_first_frame

        # Each list file contains lines of the form "<video_key>|||<caption>".
        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            with open(item_path, 'r') as f:
                lines = f.readlines()
            lines = [[data_dir, item.strip()] for item in lines]
            image_list.extend(lines)
        self.image_list = image_list

    def __getitem__(self, index):
        data_dir, file_path = self.image_list[index]
        video_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(video_key, e))
            caption = ''
            video_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption, video_key

    def _get_video_data(self, data_dir, file_path):
        video_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, video_key)

        frame_list = []
        # Retry the decode a few times before giving up.
        for _ in range(5):
            try:
                capture = cv2.VideoCapture(file_path)
                _fps = capture.get(cv2.CAP_PROP_FPS)
                _total_frame_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
                # Sample max_frames frames at roughly sample_fps; guard against fps < sample_fps.
                stride = max(round(_fps / self.sample_fps), 1)
                cover_frame_num = stride * self.max_frames
                if _total_frame_num < cover_frame_num + 5:
                    start_frame = 0
                    end_frame = _total_frame_num
                else:
                    start_frame = random.randint(0, _total_frame_num - cover_frame_num - 5)
                    end_frame = start_frame + cover_frame_num
                pointer, frame_list = 0, []
                while True:
                    ret, frame = capture.read()
                    pointer += 1
                    if (not ret) or (frame is None):
                        break
                    if pointer < start_frame:
                        continue
                    if pointer >= end_frame - 1:
                        break
                    if (pointer - start_frame) % stride == 0:
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frame_list.append(Image.fromarray(frame))
                capture.release()
                break
            except Exception as e:
                logging.info('{} read video frame failed with error: {}'.format(video_key, e))
                continue

        video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        ref_idx = 0 if self.get_first_frame else int(len(frame_list) / 2)
        try:
            if len(frame_list) > 0:
                mid_frame = copy(frame_list[ref_idx])
                vit_frame = self.vit_transforms(mid_frame)
                frames = self.transforms(frame_list)
                video_data[:len(frame_list), ...] = frames
                ref_frame = copy(frames[ref_idx])
            else:
                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
                ref_frame = video_data[0].clone()
        except Exception:
            # Fall back to blank tensors if decoding or the transforms fail.
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            ref_frame = video_data[0].clone()
        return ref_frame, vit_frame, video_data, caption

    def __len__(self):
        return len(self.image_list)
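# --- Illustrative usage (hedged sketch, not part of this commit) ---
# Placeholder values for illustration. With max_frames=16 and
# resolution=(448, 256), one sample unpacks as:
#   ref_frame  : [3, 256, 448]       first or middle decoded frame
#   vit_frame  : [3, 224, 224]       ViT-resolution view of that frame
#   video_data : [16, 3, 256, 448]   sampled clip, zero-padded to max_frames
#
#   dataset = VideoDataset(
#       data_list=['data/vid_list.txt'],      # lines: "<video_key>|||<caption>"
#       data_dir_list=['data/videos'],
#       max_frames=16, sample_fps=8,
#       resolution=(448, 256), vit_resolution=(224, 224),
#       transforms=frame_transforms,          # list of PIL frames -> [F, 3, H, W] tensor
#       vit_transforms=vit_transforms,
#       get_first_frame=True)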
from .visual_train_it2v_video import *
import os
import torch
import pynvml
import logging
from einops import rearrange
import torch.cuda.amp as amp

from utils.video_op import save_video_refimg_and_text
from utils.registry_class import VISUAL


@VISUAL.register_class()
class VisualTrainTextImageToVideo(object):
    def __init__(self, cfg_global, autoencoder, diffusion, viz_num, partial_keys=[], guide_scale=9.0, use_offset_noise=None, **kwargs):
        super(VisualTrainTextImageToVideo, self).__init__(**kwargs)
        self.cfg = cfg_global
        self.viz_num = viz_num
        self.diffusion = diffusion
        self.autoencoder = autoencoder
        self.guide_scale = guide_scale
        self.partial_keys_list = partial_keys
        self.use_offset_noise = use_offset_noise

    def prepare_model_kwargs(self, partial_keys, full_model_kwargs):
        """Keep only the requested conditioning keys from the (conditional, unconditional) kwargs pair."""
        model_kwargs = [{}, {}]
        for partial_key in partial_keys:
            model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
            model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
        return model_kwargs

    @torch.no_grad()
    def run(self,
            model,
            video_data,
            captions,
            step=0,
            ref_frame=None,
            visual_kwards=[],
            **kwargs):
        cfg = self.cfg
        viz_num = min(self.viz_num, video_data.size(0))
        noise = torch.randn_like(video_data[:viz_num])
        if self.use_offset_noise:
            noise_strength = getattr(cfg, 'noise_strength', 0)
            b, c, f, *_ = video_data[:viz_num].shape
            noise = noise + noise_strength * torch.randn(b, c, f, 1, 1, device=video_data.device)

        # Log current GPU memory usage.
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        logging.info(f'GPU Memory used {meminfo.used / (1024 ** 3):.2f} GB')

        for keys in self.partial_keys_list:
            model_kwargs = self.prepare_model_kwargs(keys, visual_kwards)
            pre_name = '_'.join(keys)

            # Sample latents with DDIM, then decode them back to pixel space in chunks.
            with amp.autocast(enabled=cfg.use_fp16):
                video_data = self.diffusion.ddim_sample_loop(
                    noise=noise.clone(),
                    model=model.eval(),
                    model_kwargs=model_kwargs,
                    guide_scale=self.guide_scale,
                    ddim_timesteps=cfg.ddim_timesteps,
                    eta=0.0)

            video_data = 1. / cfg.scale_factor * video_data  # [64, 4, 32, 48]
            video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
            chunk_size = min(cfg.decoder_bs, video_data.shape[0])
            video_data_list = torch.chunk(video_data, video_data.shape[0] // chunk_size, dim=0)
            decode_data = []
            for vd_data in video_data_list:
                gen_frames = self.autoencoder.decode(vd_data)
                decode_data.append(gen_frames)
            video_data = torch.cat(decode_data, dim=0)
            video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b=viz_num)

            text_size = cfg.resolution[-1]
            ref_frame = ref_frame[:viz_num]
            file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{cfg.sample_fps:02d}_{pre_name}'
            local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}')
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            try:
                save_video_refimg_and_text(local_path, ref_frame.cpu(), video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
            except Exception as e:
                logging.info(f'Step: {step} save text or video error with {e}')
import os
import torch
import pynvml
import logging
from einops import rearrange
import torch.cuda.amp as amp

from utils.video_op import save_video_refimg_and_text
from utils.registry_class import VISUAL


@VISUAL.register_class()
class VisualTrainTextToVideo(object):
    def __init__(self, cfg_global, autoencoder, diffusion, viz_num, partial_keys=[], guide_scale=9.0, use_offset_noise=None, **kwargs):
        super(VisualTrainTextToVideo, self).__init__(**kwargs)
        self.cfg = cfg_global
        self.viz_num = viz_num
        self.diffusion = diffusion
        self.autoencoder = autoencoder
        self.guide_scale = guide_scale
        self.partial_keys_list = partial_keys
        self.use_offset_noise = use_offset_noise

    def prepare_model_kwargs(self, partial_keys, full_model_kwargs):
        """Keep only the requested conditioning keys from the (conditional, unconditional) kwargs pair."""
        model_kwargs = [{}, {}]
        for partial_key in partial_keys:
            model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
            model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
        return model_kwargs

    @torch.no_grad()
    def run(self,
            model,
            video_data,
            captions,
            step=0,
            ref_frame=None,
            visual_kwards=[],
            **kwargs):
        cfg = self.cfg
        viz_num = self.viz_num
        noise = torch.randn_like(video_data[:viz_num])  # viz_num: 8
        if self.use_offset_noise:
            noise_strength = getattr(cfg, 'noise_strength', 0)
            b, c, f, *_ = video_data[:viz_num].shape
            noise = noise + noise_strength * torch.randn(b, c, f, 1, 1, device=video_data.device)

        # Log current GPU memory usage.
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        logging.info(f'GPU Memory used {meminfo.used / (1024 ** 3):.2f} GB')

        for keys in self.partial_keys_list:
            model_kwargs = self.prepare_model_kwargs(keys, visual_kwards)
            pre_name = '_'.join(keys)

            # Sample latents with DDIM, then decode them back to pixel space in chunks.
            with amp.autocast(enabled=cfg.use_fp16):
                video_data = self.diffusion.ddim_sample_loop(
                    noise=noise.clone(),
                    model=model.eval(),
                    model_kwargs=model_kwargs,
                    guide_scale=self.guide_scale,
                    ddim_timesteps=cfg.ddim_timesteps,
                    eta=0.0)

            video_data = 1. / cfg.scale_factor * video_data  # [64, 4, 32, 48]
            video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
            chunk_size = min(cfg.decoder_bs, video_data.shape[0])
            video_data_list = torch.chunk(video_data, video_data.shape[0] // chunk_size, dim=0)
            decode_data = []
            for vd_data in video_data_list:
                gen_frames = self.autoencoder.decode(vd_data)
                decode_data.append(gen_frames)
            video_data = torch.cat(decode_data, dim=0)
            video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b=viz_num)

            text_size = cfg.resolution[-1]
            ref_frame = ref_frame[:viz_num]
            file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{cfg.sample_fps:02d}_{pre_name}'
            local_path = os.path.join(cfg.log_dir, f'sample_{step:06d}/{file_name}')
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            try:
                save_video_refimg_and_text(local_path, ref_frame.cpu(), video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
            except Exception as e:
                logging.info(f'Step: {step} save text or video error with {e}')
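# --- Illustrative note (hedged sketch, not part of this commit) ---
# Both visualizers expect `visual_kwards` to be a pair of dicts,
# [conditional_kwargs, unconditional_kwargs], which prepare_model_kwargs
# filters down to the keys of interest before DDIM sampling with
# classifier-free guidance (guide_scale). The key names below are
# placeholders for whatever conditioning the model actually takes.
#
#   full_model_kwargs = [
#       {'y': text_embed,      'image': img_embed},       # conditional branch
#       {'y': zero_text_embed, 'image': zero_img_embed}]  # unconditional branch
#   model_kwargs = visualizer.prepare_model_kwargs(('y', 'image'), full_model_kwargs)
#   # -> [{'y': text_embed, 'image': img_embed}, {'y': zero_text_embed, 'image': zero_img_embed}]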
from .inference_i2vgen_entrance import *
from .inference_text2video_entrance import *
from .inference_higen_entrance import *
from .inference_sr600_entrance import *