Commit efb4d161 authored by helloyongyang

Remove args-based argument passing; pass a unified config everywhere to simplify the code

parent f4b343f6
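
This commit replaces the previous pattern of threading an argparse namespace (args) plus a separate model_config dict through every function with a single config object built by the new set_config helper (shown at the end of this diff). A minimal sketch of the resulting call pattern, assuming the package is importable from the run environment; the argparse flags below are an illustrative subset, not the script's full CLI:

import argparse
from lightx2v.utils.set_config import set_config

parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", default="wan2.1")   # illustrative subset of the real flags
parser.add_argument("--config_path", default=None)     # optional JSON file that set_config merges in
parser.add_argument("--mm_config", default=None)       # optional JSON string parsed by set_config
args = parser.parse_args()

config = set_config(args)        # argparse namespace -> EasyDict
print(config.model_cls)          # attribute access
print(config["model_cls"])       # dict-style access works on the same object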
......@@ -29,6 +29,7 @@ from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
from lightx2v.common.ops import *
from lightx2v.image2v.models.wan.model import CLIPModel
from lightx2v.utils.set_config import set_config
@contextmanager
......@@ -41,92 +42,92 @@ def time_duration(label: str = ""):
print(f"==> {label} start:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")
def load_models(args, model_config):
if model_config["parallel_attn_type"]:
def load_models(config):
if config["parallel_attn_type"]:
cur_rank = dist.get_rank() # get the rank of the current process
torch.cuda.set_device(cur_rank) # set the CUDA device for the current process
image_encoder = None
if args.cpu_offload:
if config.cpu_offload:
init_device = torch.device("cpu")
else:
init_device = torch.device("cuda")
if args.model_cls == "hunyuan":
if args.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
if config.model_cls == "hunyuan":
if config.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(config.model_path, "text_encoder"), init_device)
else:
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(args.model_path, "text_encoder_i2v"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(config.model_path, "text_encoder_i2v"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(config.model_path, "text_encoder_2"), init_device)
text_encoders = [text_encoder_1, text_encoder_2]
model = HunyuanModel(args.model_path, model_config, init_device, args)
vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device, args=args)
model = HunyuanModel(config.model_path, config, init_device, config)
vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)
elif args.model_cls == "wan2.1":
elif config.model_cls == "wan2.1":
with time_duration("Load Text Encoder"):
text_encoder = T5EncoderModel(
text_len=model_config["text_len"],
text_len=config["text_len"],
dtype=torch.bfloat16,
device=init_device,
checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
checkpoint_path=os.path.join(config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(config.model_path, "google/umt5-xxl"),
shard_fn=None,
)
text_encoders = [text_encoder]
with time_duration("Load Wan Model"):
model = WanModel(args.model_path, model_config, init_device)
model = WanModel(config.model_path, config, init_device)
if args.lora_path:
if config.lora_path:
lora_wrapper = WanLoraWrapper(model)
with time_duration("Load LoRA Model"):
lora_name = lora_wrapper.load_lora(args.lora_path)
lora_wrapper.apply_lora(lora_name, args.strength_model)
lora_name = lora_wrapper.load_lora(config.lora_path)
lora_wrapper.apply_lora(lora_name, config.strength_model)
print(f"Loaded LoRA: {lora_name}")
with time_duration("Load WAN VAE Model"):
vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
if args.task == "i2v":
vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
if config.task == "i2v":
with time_duration("Load Image Encoder"):
image_encoder = CLIPModel(
dtype=torch.float16,
device=init_device,
checkpoint_path=os.path.join(args.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"),
checkpoint_path=os.path.join(config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
tokenizer_path=os.path.join(config.model_path, "xlm-roberta-large"),
)
else:
raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
return model, text_encoders, vae_model, image_encoder
def set_target_shape(args, image_encoder_output):
if args.model_cls == "hunyuan":
if args.task == "t2v":
def set_target_shape(config, image_encoder_output):
if config.model_cls == "hunyuan":
if config.task == "t2v":
vae_scale_factor = 2 ** (4 - 1)
args.target_shape = (
config.target_shape = (
1,
16,
(args.target_video_length - 1) // 4 + 1,
int(args.target_height) // vae_scale_factor,
int(args.target_width) // vae_scale_factor,
(config.target_video_length - 1) // 4 + 1,
int(config.target_height) // vae_scale_factor,
int(config.target_width) // vae_scale_factor,
)
elif args.task == "i2v":
elif config.task == "i2v":
vae_scale_factor = 2 ** (4 - 1)
args.target_shape = (
config.target_shape = (
1,
16,
(args.target_video_length - 1) // 4 + 1,
(config.target_video_length - 1) // 4 + 1,
int(image_encoder_output["target_height"]) // vae_scale_factor,
int(image_encoder_output["target_width"]) // vae_scale_factor,
)
elif args.model_cls == "wan2.1":
if args.task == "i2v":
args.target_shape = (16, 21, args.lat_h, args.lat_w)
elif args.task == "t2v":
args.target_shape = (
elif config.model_cls == "wan2.1":
if config.task == "i2v":
config.target_shape = (16, 21, config.lat_h, config.lat_w)
elif config.task == "t2v":
config.target_shape = (
16,
(args.target_video_length - 1) // 4 + 1,
int(args.target_height) // args.vae_stride[1],
int(args.target_width) // args.vae_stride[2],
(config.target_video_length - 1) // 4 + 1,
int(config.target_height) // config.vae_stride[1],
int(config.target_width) // config.vae_stride[2],
)
......@@ -161,9 +162,9 @@ def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
return closest_size, closest_ratio
def run_image_encoder(args, image_encoder, vae_model):
if args.model_cls == "hunyuan":
img = Image.open(args.image_path).convert("RGB")
def run_image_encoder(config, image_encoder, vae_model):
if config.model_cls == "hunyuan":
img = Image.open(config.image_path).convert("RGB")
origin_size = img.size
i2v_resolution = "720p"
......@@ -190,7 +191,7 @@ def run_image_encoder(args, image_encoder, vae_model):
semantic_image_pixel_values = [ref_image_transform(img)]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
img_latents = vae_model.encode(semantic_image_pixel_values, args).mode()
img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()
scaling_factor = 0.476986
img_latents.mul_(scaling_factor)
......@@ -199,20 +200,20 @@ def run_image_encoder(args, image_encoder, vae_model):
return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}
elif args.model_cls == "wan2.1":
img = Image.open(args.image_path).convert("RGB")
elif config.model_cls == "wan2.1":
img = Image.open(config.image_path).convert("RGB")
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = image_encoder.visual([img[:, None, :, :]], args).squeeze(0).to(torch.bfloat16)
clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
h, w = img.shape[1:]
aspect_ratio = h / w
max_area = args.target_height * args.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // args.vae_stride[1] // args.patch_size[1] * args.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // args.vae_stride[2] // args.patch_size[2] * args.patch_size[2])
h = lat_h * args.vae_stride[1]
w = lat_w * args.vae_stride[2]
max_area = config.target_height * config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
h = lat_h * config.vae_stride[1]
w = lat_w * config.vae_stride[2]
args.lat_h = lat_h
args.lat_w = lat_w
config.lat_h = lat_h
config.lat_w = lat_w
msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
msk[:, 1:] = 0
......@@ -220,64 +221,64 @@ def run_image_encoder(args, image_encoder, vae_model):
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
vae_encode_out = vae_model.encode(
[torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], args
[torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
)[0]
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
else:
raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
def run_text_encoder(args, text, text_encoders, model_config, image_encoder_output):
def run_text_encoder(text, text_encoders, config, image_encoder_output):
text_encoder_output = {}
if args.model_cls == "hunyuan":
if config.model_cls == "hunyuan":
for i, encoder in enumerate(text_encoders):
if args.task == "i2v" and i == 0:
text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], args)
if config.task == "i2v" and i == 0:
text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
else:
text_state, attention_mask = encoder.infer(text, args)
text_state, attention_mask = encoder.infer(text, config)
text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
elif args.model_cls == "wan2.1":
n_prompt = model_config.get("sample_neg_prompt", "")
context = text_encoders[0].infer([text], args)
context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], args)
elif config.model_cls == "wan2.1":
n_prompt = config.get("sample_neg_prompt", "")
context = text_encoders[0].infer([text], config)
context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null
else:
raise NotImplementedError(f"Unsupported model type: {args.model_cls}")
raise NotImplementedError(f"Unsupported model type: {config.model_cls}")
return text_encoder_output
def init_scheduler(args, image_encoder_output):
if args.model_cls == "hunyuan":
if args.feature_caching == "NoCaching":
scheduler = HunyuanScheduler(args, image_encoder_output)
elif args.feature_caching == "Tea":
scheduler = HunyuanSchedulerTeaCaching(args, image_encoder_output)
elif args.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerTaylorCaching(args, image_encoder_output)
def init_scheduler(config, image_encoder_output):
if config.model_cls == "hunyuan":
if config.feature_caching == "NoCaching":
scheduler = HunyuanScheduler(config, image_encoder_output)
elif config.feature_caching == "Tea":
scheduler = HunyuanSchedulerTeaCaching(config, image_encoder_output)
elif config.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerTaylorCaching(config, image_encoder_output)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
elif args.model_cls == "wan2.1":
if args.feature_caching == "NoCaching":
scheduler = WanScheduler(args)
elif args.feature_caching == "Tea":
scheduler = WanSchedulerTeaCaching(args)
elif config.model_cls == "wan2.1":
if config.feature_caching == "NoCaching":
scheduler = WanScheduler(config)
elif config.feature_caching == "Tea":
scheduler = WanSchedulerTeaCaching(config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
else:
raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
return scheduler
def run_main_inference(args, model, text_encoder_output, image_encoder_output):
def run_main_inference(model, inputs):
for step_index in range(model.scheduler.infer_steps):
torch.cuda.synchronize()
time1 = time.time()
......@@ -287,7 +288,7 @@ def run_main_inference(args, model, text_encoder_output, image_encoder_output):
torch.cuda.synchronize()
time2 = time.time()
model.infer(text_encoder_output, image_encoder_output, args)
model.infer(inputs)
torch.cuda.synchronize()
time3 = time.time()
......@@ -304,8 +305,8 @@ def run_main_inference(args, model, text_encoder_output, image_encoder_output):
return model.scheduler.latents, model.scheduler.generator
def run_vae(latents, generator, args):
images = vae_model.decode(latents, generator=generator, args=args)
def run_vae(latents, generator, config):
images = vae_model.decode(latents, generator=generator, config=config)
return images
......@@ -348,69 +349,49 @@ if __name__ == "__main__":
seed_all(args.seed)
if args.parallel_attn_type:
config = set_config(args)
if config.parallel_attn_type:
dist.init_process_group(backend="nccl")
if args.mm_config:
mm_config = json.loads(args.mm_config)
else:
mm_config = None
model_config = {
"model_cls": args.model_cls,
"task": args.task,
"attention_type": args.attention_type,
"sample_neg_prompt": args.sample_neg_prompt,
"mm_config": mm_config,
"do_mm_calib": args.do_mm_calib,
"cpu_offload": args.cpu_offload,
"feature_caching": args.feature_caching,
"parallel_attn_type": args.parallel_attn_type,
"parallel_vae": args.parallel_vae,
"use_bfloat16": args.use_bfloat16,
}
if args.config_path is not None:
with open(args.config_path, "r") as f:
config = json.load(f)
model_config.update(config)
print(f"model_config: {model_config}")
print(f"config: {config}")
with time_duration("Load models"):
model, text_encoders, vae_model, image_encoder = load_models(args, model_config)
model, text_encoders, vae_model, image_encoder = load_models(config)
if args.task in ["i2v"]:
image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
if config["task"] in ["i2v"]:
image_encoder_output = run_image_encoder(config, image_encoder, vae_model)
else:
image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
with time_duration("Run Text Encoder"):
text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)
inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
set_target_shape(args, image_encoder_output)
scheduler = init_scheduler(args, image_encoder_output)
set_target_shape(config, image_encoder_output)
scheduler = init_scheduler(config, image_encoder_output)
model.set_scheduler(scheduler)
gc.collect()
torch.cuda.empty_cache()
latents, generator = run_main_inference(args, model, text_encoder_output, image_encoder_output)
latents, generator = run_main_inference(model, inputs)
if args.cpu_offload:
if config.cpu_offload:
scheduler.clear()
del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
torch.cuda.empty_cache()
with time_duration("Run VAE"):
images = run_vae(latents, generator, args)
images = run_vae(latents, generator, config)
if not args.parallel_attn_type or (args.parallel_attn_type and dist.get_rank() == 0):
if not config.parallel_attn_type or (config.parallel_attn_type and dist.get_rank() == 0):
with time_duration("Save video"):
if args.model_cls == "wan2.1":
cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
if config.model_cls == "wan2.1":
cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
else:
save_videos_grid(images, args.save_video_path, fps=24)
save_videos_grid(images, config.save_video_path, fps=24)
end_time = time.time()
print(f"Total cost: {end_time - start_time}")
......@@ -2,17 +2,20 @@ import torch
class HunyuanPostInfer:
def __init__(self):
pass
def __init__(self, config):
self.config = config
def infer(self, weights, img, vec, shape):
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, img, vec):
out = torch.nn.functional.silu(vec)
out = weights.final_layer_adaLN_modulation_1.apply(out)
shift, scale = out.chunk(2, dim=1)
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
out = out * (1 + scale) + shift
out = weights.final_layer_linear.apply(out.to(torch.float32))
_, _, ot, oh, ow = shape
_, _, ot, oh, ow = self.scheduler.latents.shape
patch_size = [1, 2, 2]
tt, th, tw = (
ot // patch_size[0],
......
......@@ -5,11 +5,25 @@ from lightx2v.attentions import attention
class HunyuanPreInfer:
def __init__(self):
def __init__(self, config):
self.heads_num = 24
self.config = config
def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance, img_latents=None):
if img_latents is not None:
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, inputs):
x = self.scheduler.latents
t = self.scheduler.timesteps[self.scheduler.step_index]
freqs_cos = self.scheduler.freqs_cos
freqs_sin = self.scheduler.freqs_sin
guidance = self.scheduler.guidance
text_states = inputs["text_encoder_output"]["text_encoder_1_text_states"]
text_mask = inputs["text_encoder_output"]["text_encoder_1_attention_mask"]
text_states_2 = inputs["text_encoder_output"]["text_encoder_2_text_states"]
if self.config["task"] == "i2v":
token_replace_t = torch.zeros_like(t)
token_replace_vec = self.infer_time_in(weights, token_replace_t)
th = x.shape[-2] // 2
......@@ -22,7 +36,7 @@ class HunyuanPreInfer:
infer_vector_out = self.infer_vector_in(weights, text_states_2)
vec = time_out + infer_vector_out
if img_latents is not None:
if self.config["task"] == "i2v":
token_replace_vec = token_replace_vec + infer_vector_out
guidance_out = self.infer_guidance_in(weights, guidance)
......@@ -43,7 +57,7 @@ class HunyuanPreInfer:
cu_seqlens_qkv[2 * i + 2] = s2
max_seqlen_qkv = img_seq_len + txt_seq_len
if img_latents is not None:
if self.config["task"] == "i2v":
return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)
......
......@@ -69,12 +69,14 @@ class HunyuanModel:
self.transformer_weights.load_weights(weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class()
self.post_infer = self.post_infer_class()
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
self.post_infer.set_scheduler(scheduler)
self.transformer_infer.set_scheduler(scheduler)
def to_cpu(self):
......@@ -88,28 +90,18 @@ class HunyuanModel:
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, text_encoder_output, image_encoder_output, args):
def infer(self, inputs):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
pre_infer_out = self.pre_infer.infer(
self.pre_weight,
self.scheduler.latents,
self.scheduler.timesteps[self.scheduler.step_index],
text_encoder_output["text_encoder_1_text_states"],
text_encoder_output["text_encoder_1_attention_mask"],
text_encoder_output["text_encoder_2_text_states"],
self.scheduler.freqs_cos,
self.scheduler.freqs_sin,
self.scheduler.guidance,
img_latents=image_encoder_output["img_latents"] if "img_latents" in image_encoder_output else None,
)
img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
inputs = self.pre_infer.infer(self.pre_weight, inputs)
inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt == self.scheduler.num_steps:
......
......@@ -8,6 +8,9 @@ class WanPostInfer:
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, x, e, grid_sizes):
e = (weights.head_modulation + e.unsqueeze(1)).chunk(2, dim=1)
norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
......
......@@ -22,7 +22,20 @@ class WanPreInfer:
self.dim = config["dim"]
self.text_len = config["text_len"]
def infer(self, weights, x, t, context, seq_len, clip_fea=None, y=None):
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, inputs, positive):
x = [self.scheduler.latents]
t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
if positive:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
y = [inputs["image_encoder_output"]["vae_encode_out"]]
if self.task == "i2v":
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
......
......@@ -95,6 +95,8 @@ class WanModel:
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
self.post_infer.set_scheduler(scheduler)
self.transformer_infer.set_scheduler(scheduler)
def to_cpu(self):
......@@ -108,24 +110,13 @@ class WanModel:
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, text_encoders_output, image_encoder_output, args):
timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
def infer(self, inputs):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
self.pre_weight,
[self.scheduler.latents],
timestep,
text_encoders_output["context"],
self.scheduler.seq_len,
image_encoder_output["clip_encoder_out"],
[image_encoder_output["vae_encode_out"]],
)
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
if self.config["feature_caching"] == "Tea":
......@@ -133,16 +124,7 @@ class WanModel:
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
self.pre_weight,
[self.scheduler.latents],
timestep,
text_encoders_output["context_null"],
self.scheduler.seq_len,
image_encoder_output["clip_encoder_out"],
[image_encoder_output["vae_encode_out"]],
)
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
......@@ -151,7 +133,7 @@ class WanModel:
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
......
......@@ -23,8 +23,8 @@ class TextEncoderHFClipModel:
self.model = self.model.to("cuda")
@torch.no_grad()
def infer(self, text, args):
if args.cpu_offload:
def infer(self, text, config):
if config.cpu_offload:
self.to_cuda()
tokens = self.tokenizer(
text,
......@@ -44,7 +44,7 @@ class TextEncoderHFClipModel:
)
last_hidden_state = outputs["pooler_output"]
if args.cpu_offload:
if config.cpu_offload:
self.to_cpu()
return last_hidden_state, tokens["attention_mask"]
......
......@@ -34,8 +34,8 @@ class TextEncoderHFLlamaModel:
self.model = self.model.to("cuda")
@torch.no_grad()
def infer(self, text, args):
if args.cpu_offload:
def infer(self, text, config):
if config.cpu_offload:
self.to_cuda()
text = self.prompt_template.format(text)
tokens = self.tokenizer(
......@@ -57,7 +57,7 @@ class TextEncoderHFLlamaModel:
last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start :]
attention_mask = tokens["attention_mask"][:, self.crop_start :]
if args.cpu_offload:
if config.cpu_offload:
self.to_cpu()
return last_hidden_state, attention_mask
......
......@@ -98,9 +98,9 @@ class TextEncoderHFLlavaModel:
self.model = self.model.to("cuda")
@torch.no_grad()
def infer(self, text, img, args):
# if args.cpu_offload:
# self.to_cuda()
def infer(self, text, img, config):
if config.cpu_offload:
self.to_cuda()
text = self.prompt_template.format(text)
print(f"text: {text}")
tokens = self.tokenizer(
......@@ -148,8 +148,8 @@ class TextEncoderHFLlavaModel:
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
# if args.cpu_offload:
# self.to_cpu()
if config.cpu_offload:
self.to_cpu()
return last_hidden_state, attention_mask
......
......@@ -492,8 +492,8 @@ class T5EncoderModel:
def to_cuda(self):
self.model = self.model.to("cuda")
def infer(self, texts, args):
if args.cpu_offload:
def infer(self, texts, config):
if config.cpu_offload:
self.to_cuda()
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
......@@ -502,7 +502,7 @@ class T5EncoderModel:
seq_lens = mask.gt(0).sum(dim=1).long()
context = self.model(ids, mask)
if args.cpu_offload:
if config.cpu_offload:
self.to_cpu()
return [u[:v] for u, v in zip(context, seq_lens)]
......
......@@ -4,15 +4,15 @@ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDis
class VideoEncoderKLCausal3DModel:
def __init__(self, model_path, dtype, device, args):
def __init__(self, model_path, dtype, device, config):
self.model_path = model_path
self.dtype = dtype
self.device = device
self.args = args
self.config = config
self.load()
def load(self):
if self.args.task == "t2v":
if self.config.task == "t2v":
self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
else:
self.vae_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/vae")
......@@ -30,8 +30,8 @@ class VideoEncoderKLCausal3DModel:
def to_cuda(self):
self.model = self.model.to("cuda")
def decode(self, latents, generator, args):
if args.cpu_offload:
def decode(self, latents, generator, config):
if config.cpu_offload:
self.to_cuda()
latents = latents / self.model.config.scaling_factor
latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
......@@ -39,11 +39,11 @@ class VideoEncoderKLCausal3DModel:
image = self.model.decode(latents, return_dict=False, generator=generator)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().float()
if args.cpu_offload:
if config.cpu_offload:
self.to_cpu()
return image
def encode(self, x, args):
def encode(self, x, config):
h = self.model.encoder(x)
moments = self.model.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
......
......@@ -786,8 +786,8 @@ class WanVAE:
return images
def decode(self, zs, generator, args):
if args.cpu_offload:
def decode(self, zs, generator, config):
if config.cpu_offload:
self.to_cuda()
if self.parallel:
......@@ -806,7 +806,7 @@ class WanVAE:
else:
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
if args.cpu_offload:
if config.cpu_offload:
images = images.cpu().float()
self.to_cpu()
......
import json
from easydict import EasyDict
def set_config(args):
config = {k: v for k, v in vars(args).items()}
config = EasyDict(config)
if args.mm_config:
config.mm_config = json.loads(args.mm_config)
else:
config.mm_config = None
if args.config_path is not None:
with open(args.config_path, "r") as f:
model_config = json.load(f)
config.update(model_config)
return config
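
A note on the mixed access styles seen in the diff above (config.cpu_offload vs. config["task"]): EasyDict exposes the same underlying dict through both attribute and key lookup, so call sites did not have to be rewritten to one style. A standalone illustration, independent of the repository:

from easydict import EasyDict

config = EasyDict({"task": "i2v", "cpu_offload": False})
config.update({"sample_guide_scale": 5.0})                # e.g. values merged in from a JSON config file

assert config.task == config["task"] == "i2v"             # attribute and key access agree
assert config.get("sample_neg_prompt", "") == ""          # plain dict methods still work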