Commit efb4d161 authored by helloyongyang

Remove args-based argument passing, pass everything through a unified config, and simplify the code

parent f4b343f6
@@ -29,6 +29,7 @@ from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
 from lightx2v.common.ops import *
 from lightx2v.image2v.models.wan.model import CLIPModel
+from lightx2v.utils.set_config import set_config

 @contextmanager
@@ -41,92 +42,92 @@ def time_duration(label: str = ""):
     print(f"==> {label} start:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")

-def load_models(args, model_config):
-    if model_config["parallel_attn_type"]:
+def load_models(config):
+    if config["parallel_attn_type"]:
         cur_rank = dist.get_rank()  # get the rank of the current process
         torch.cuda.set_device(cur_rank)  # set the CUDA device for the current process
     image_encoder = None
-    if args.cpu_offload:
+    if config.cpu_offload:
         init_device = torch.device("cpu")
     else:
         init_device = torch.device("cuda")
-    if args.model_cls == "hunyuan":
-        if args.task == "t2v":
-            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
+    if config.model_cls == "hunyuan":
+        if config.task == "t2v":
+            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(config.model_path, "text_encoder"), init_device)
         else:
-            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(args.model_path, "text_encoder_i2v"), init_device)
-        text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
+            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(config.model_path, "text_encoder_i2v"), init_device)
+        text_encoder_2 = TextEncoderHFClipModel(os.path.join(config.model_path, "text_encoder_2"), init_device)
         text_encoders = [text_encoder_1, text_encoder_2]
-        model = HunyuanModel(args.model_path, model_config, init_device, args)
-        vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device, args=args)
-    elif args.model_cls == "wan2.1":
+        model = HunyuanModel(config.model_path, config, init_device, config)
+        vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)
+    elif config.model_cls == "wan2.1":
         with time_duration("Load Text Encoder"):
             text_encoder = T5EncoderModel(
-                text_len=model_config["text_len"],
+                text_len=config["text_len"],
                 dtype=torch.bfloat16,
                 device=init_device,
-                checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
-                tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
+                checkpoint_path=os.path.join(config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
+                tokenizer_path=os.path.join(config.model_path, "google/umt5-xxl"),
                 shard_fn=None,
             )
         text_encoders = [text_encoder]
         with time_duration("Load Wan Model"):
-            model = WanModel(args.model_path, model_config, init_device)
-        if args.lora_path:
+            model = WanModel(config.model_path, config, init_device)
+        if config.lora_path:
             lora_wrapper = WanLoraWrapper(model)
             with time_duration("Load LoRA Model"):
-                lora_name = lora_wrapper.load_lora(args.lora_path)
-                lora_wrapper.apply_lora(lora_name, args.strength_model)
+                lora_name = lora_wrapper.load_lora(config.lora_path)
+                lora_wrapper.apply_lora(lora_name, config.strength_model)
                 print(f"Loaded LoRA: {lora_name}")
         with time_duration("Load WAN VAE Model"):
-            vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
-        if args.task == "i2v":
+            vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
+        if config.task == "i2v":
             with time_duration("Load Image Encoder"):
                 image_encoder = CLIPModel(
                     dtype=torch.float16,
                     device=init_device,
-                    checkpoint_path=os.path.join(args.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
-                    tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"),
+                    checkpoint_path=os.path.join(config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
+                    tokenizer_path=os.path.join(config.model_path, "xlm-roberta-large"),
                 )
     else:
-        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
+        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
     return model, text_encoders, vae_model, image_encoder

-def set_target_shape(args, image_encoder_output):
-    if args.model_cls == "hunyuan":
-        if args.task == "t2v":
+def set_target_shape(config, image_encoder_output):
+    if config.model_cls == "hunyuan":
+        if config.task == "t2v":
             vae_scale_factor = 2 ** (4 - 1)
-            args.target_shape = (
+            config.target_shape = (
                 1,
                 16,
-                (args.target_video_length - 1) // 4 + 1,
-                int(args.target_height) // vae_scale_factor,
-                int(args.target_width) // vae_scale_factor,
+                (config.target_video_length - 1) // 4 + 1,
+                int(config.target_height) // vae_scale_factor,
+                int(config.target_width) // vae_scale_factor,
             )
-        elif args.task == "i2v":
+        elif config.task == "i2v":
             vae_scale_factor = 2 ** (4 - 1)
-            args.target_shape = (
+            config.target_shape = (
                 1,
                 16,
-                (args.target_video_length - 1) // 4 + 1,
+                (config.target_video_length - 1) // 4 + 1,
                 int(image_encoder_output["target_height"]) // vae_scale_factor,
                 int(image_encoder_output["target_width"]) // vae_scale_factor,
             )
-    elif args.model_cls == "wan2.1":
-        if args.task == "i2v":
-            args.target_shape = (16, 21, args.lat_h, args.lat_w)
-        elif args.task == "t2v":
-            args.target_shape = (
+    elif config.model_cls == "wan2.1":
+        if config.task == "i2v":
+            config.target_shape = (16, 21, config.lat_h, config.lat_w)
+        elif config.task == "t2v":
+            config.target_shape = (
                 16,
-                (args.target_video_length - 1) // 4 + 1,
-                int(args.target_height) // args.vae_stride[1],
-                int(args.target_width) // args.vae_stride[2],
+                (config.target_video_length - 1) // 4 + 1,
+                int(config.target_height) // config.vae_stride[1],
+                int(config.target_width) // config.vae_stride[2],
             )
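Note (not part of the diff): a quick trace of the wan2.1 t2v branch of set_target_shape with assumed numbers; vae_stride=(4, 8, 8) is not shown in this diff and is only an illustrative value.

# Hypothetical values chosen to illustrate the shape arithmetic above.
target_video_length, target_height, target_width = 81, 480, 832
vae_stride = (4, 8, 8)  # assumed (t, h, w) VAE strides, not taken from this diff

target_shape = (
    16,                                   # latent channels
    (target_video_length - 1) // 4 + 1,   # 21 latent frames
    int(target_height) // vae_stride[1],  # 60
    int(target_width) // vae_stride[2],   # 104
)
print(target_shape)  # (16, 21, 60, 104)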
@@ -161,9 +162,9 @@ def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
     return closest_size, closest_ratio

-def run_image_encoder(args, image_encoder, vae_model):
-    if args.model_cls == "hunyuan":
-        img = Image.open(args.image_path).convert("RGB")
+def run_image_encoder(config, image_encoder, vae_model):
+    if config.model_cls == "hunyuan":
+        img = Image.open(config.image_path).convert("RGB")
         origin_size = img.size
         i2v_resolution = "720p"
@@ -190,7 +191,7 @@ def run_image_encoder(args, image_encoder, vae_model):
         semantic_image_pixel_values = [ref_image_transform(img)]
         semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
-        img_latents = vae_model.encode(semantic_image_pixel_values, args).mode()
+        img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()
         scaling_factor = 0.476986
         img_latents.mul_(scaling_factor)
@@ -199,20 +200,20 @@ def run_image_encoder(args, image_encoder, vae_model):
         return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}
-    elif args.model_cls == "wan2.1":
-        img = Image.open(args.image_path).convert("RGB")
+    elif config.model_cls == "wan2.1":
+        img = Image.open(config.image_path).convert("RGB")
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], args).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
         h, w = img.shape[1:]
         aspect_ratio = h / w
-        max_area = args.target_height * args.target_width
-        lat_h = round(np.sqrt(max_area * aspect_ratio) // args.vae_stride[1] // args.patch_size[1] * args.patch_size[1])
-        lat_w = round(np.sqrt(max_area / aspect_ratio) // args.vae_stride[2] // args.patch_size[2] * args.patch_size[2])
-        h = lat_h * args.vae_stride[1]
-        w = lat_w * args.vae_stride[2]
-        args.lat_h = lat_h
-        args.lat_w = lat_w
+        max_area = config.target_height * config.target_width
+        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
+        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
+        h = lat_h * config.vae_stride[1]
+        w = lat_w * config.vae_stride[2]
+        config.lat_h = lat_h
+        config.lat_w = lat_w
         msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
         msk[:, 1:] = 0
@@ -220,64 +221,64 @@ def run_image_encoder(args, image_encoder, vae_model):
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
         vae_encode_out = vae_model.encode(
-            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], args
+            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
         )[0]
         vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
         return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
     else:
-        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
+        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
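Note (not part of the diff): the latent-grid rounding in the wan2.1 i2v branch above, traced with hypothetical numbers. vae_stride and the input image size are assumed values; patch_size=(1, 2, 2) matches the value hard-coded elsewhere in this diff.

# Hypothetical trace of the aspect-preserving latent sizing (values assumed).
import numpy as np

target_height, target_width = 480, 832
vae_stride, patch_size = (4, 8, 8), (1, 2, 2)  # assumed config values
h, w = 720, 1280                               # example reference image size
aspect_ratio = h / w
max_area = target_height * target_width

lat_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2])
print(lat_h, lat_w)                                  # 58 104, latent grid aligned to the patch size
print(lat_h * vae_stride[1], lat_w * vae_stride[2])  # 464 832, pixel size fed to the VAE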
-def run_text_encoder(args, text, text_encoders, model_config, image_encoder_output):
+def run_text_encoder(text, text_encoders, config, image_encoder_output):
     text_encoder_output = {}
-    if args.model_cls == "hunyuan":
+    if config.model_cls == "hunyuan":
         for i, encoder in enumerate(text_encoders):
-            if args.task == "i2v" and i == 0:
-                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], args)
+            if config.task == "i2v" and i == 0:
+                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
             else:
-                text_state, attention_mask = encoder.infer(text, args)
+                text_state, attention_mask = encoder.infer(text, config)
             text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
             text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
-    elif args.model_cls == "wan2.1":
-        n_prompt = model_config.get("sample_neg_prompt", "")
-        context = text_encoders[0].infer([text], args)
-        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], args)
+    elif config.model_cls == "wan2.1":
+        n_prompt = config.get("sample_neg_prompt", "")
+        context = text_encoders[0].infer([text], config)
+        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
         text_encoder_output["context"] = context
         text_encoder_output["context_null"] = context_null
     else:
-        raise NotImplementedError(f"Unsupported model type: {args.model_cls}")
+        raise NotImplementedError(f"Unsupported model type: {config.model_cls}")
     return text_encoder_output

-def init_scheduler(args, image_encoder_output):
-    if args.model_cls == "hunyuan":
-        if args.feature_caching == "NoCaching":
-            scheduler = HunyuanScheduler(args, image_encoder_output)
-        elif args.feature_caching == "Tea":
-            scheduler = HunyuanSchedulerTeaCaching(args, image_encoder_output)
-        elif args.feature_caching == "TaylorSeer":
-            scheduler = HunyuanSchedulerTaylorCaching(args, image_encoder_output)
+def init_scheduler(config, image_encoder_output):
+    if config.model_cls == "hunyuan":
+        if config.feature_caching == "NoCaching":
+            scheduler = HunyuanScheduler(config, image_encoder_output)
+        elif config.feature_caching == "Tea":
+            scheduler = HunyuanSchedulerTeaCaching(config, image_encoder_output)
+        elif config.feature_caching == "TaylorSeer":
+            scheduler = HunyuanSchedulerTaylorCaching(config, image_encoder_output)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
-    elif args.model_cls == "wan2.1":
-        if args.feature_caching == "NoCaching":
-            scheduler = WanScheduler(args)
-        elif args.feature_caching == "Tea":
-            scheduler = WanSchedulerTeaCaching(args)
+            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
+    elif config.model_cls == "wan2.1":
+        if config.feature_caching == "NoCaching":
+            scheduler = WanScheduler(config)
+        elif config.feature_caching == "Tea":
+            scheduler = WanSchedulerTeaCaching(config)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {args.feature_caching}")
+            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
     else:
-        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
+        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
     return scheduler

-def run_main_inference(args, model, text_encoder_output, image_encoder_output):
+def run_main_inference(model, inputs):
     for step_index in range(model.scheduler.infer_steps):
         torch.cuda.synchronize()
         time1 = time.time()
@@ -287,7 +288,7 @@ def run_main_inference(args, model, text_encoder_output, image_encoder_output):
         torch.cuda.synchronize()
         time2 = time.time()
-        model.infer(text_encoder_output, image_encoder_output, args)
+        model.infer(inputs)
         torch.cuda.synchronize()
         time3 = time.time()
@@ -304,8 +305,8 @@ def run_main_inference(args, model, text_encoder_output, image_encoder_output):
     return model.scheduler.latents, model.scheduler.generator

-def run_vae(latents, generator, args):
-    images = vae_model.decode(latents, generator=generator, args=args)
+def run_vae(latents, generator, config):
+    images = vae_model.decode(latents, generator=generator, config=config)
     return images
@@ -348,69 +349,49 @@ if __name__ == "__main__":
     seed_all(args.seed)
-    if args.parallel_attn_type:
+    config = set_config(args)
+    if config.parallel_attn_type:
         dist.init_process_group(backend="nccl")
-    if args.mm_config:
-        mm_config = json.loads(args.mm_config)
-    else:
-        mm_config = None
-    model_config = {
-        "model_cls": args.model_cls,
-        "task": args.task,
-        "attention_type": args.attention_type,
-        "sample_neg_prompt": args.sample_neg_prompt,
-        "mm_config": mm_config,
-        "do_mm_calib": args.do_mm_calib,
-        "cpu_offload": args.cpu_offload,
-        "feature_caching": args.feature_caching,
-        "parallel_attn_type": args.parallel_attn_type,
-        "parallel_vae": args.parallel_vae,
-        "use_bfloat16": args.use_bfloat16,
-    }
-    if args.config_path is not None:
-        with open(args.config_path, "r") as f:
-            config = json.load(f)
-        model_config.update(config)
-    print(f"model_config: {model_config}")
+    print(f"config: {config}")
     with time_duration("Load models"):
-        model, text_encoders, vae_model, image_encoder = load_models(args, model_config)
+        model, text_encoders, vae_model, image_encoder = load_models(config)
-    if args.task in ["i2v"]:
-        image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
+    if config["task"] in ["i2v"]:
+        image_encoder_output = run_image_encoder(config, image_encoder, vae_model)
     else:
         image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
     with time_duration("Run Text Encoder"):
-        text_encoder_output = run_text_encoder(args, args.prompt, text_encoders, model_config, image_encoder_output)
+        text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)
+    inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
-    set_target_shape(args, image_encoder_output)
-    scheduler = init_scheduler(args, image_encoder_output)
+    set_target_shape(config, image_encoder_output)
+    scheduler = init_scheduler(config, image_encoder_output)
     model.set_scheduler(scheduler)
     gc.collect()
     torch.cuda.empty_cache()
-    latents, generator = run_main_inference(args, model, text_encoder_output, image_encoder_output)
+    latents, generator = run_main_inference(model, inputs)
-    if args.cpu_offload:
+    if config.cpu_offload:
         scheduler.clear()
         del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
         torch.cuda.empty_cache()
     with time_duration("Run VAE"):
-        images = run_vae(latents, generator, args)
+        images = run_vae(latents, generator, config)
-    if not args.parallel_attn_type or (args.parallel_attn_type and dist.get_rank() == 0):
+    if not config.parallel_attn_type or (config.parallel_attn_type and dist.get_rank() == 0):
         with time_duration("Save video"):
-            if args.model_cls == "wan2.1":
-                cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
+            if config.model_cls == "wan2.1":
+                cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
             else:
-                save_videos_grid(images, args.save_video_path, fps=24)
+                save_videos_grid(images, config.save_video_path, fps=24)
     end_time = time.time()
     print(f"Total cost: {end_time - start_time}")
@@ -2,17 +2,20 @@ import torch

 class HunyuanPostInfer:
-    def __init__(self):
-        pass
+    def __init__(self, config):
+        self.config = config
+
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler

-    def infer(self, weights, img, vec, shape):
+    def infer(self, weights, img, vec):
         out = torch.nn.functional.silu(vec)
         out = weights.final_layer_adaLN_modulation_1.apply(out)
         shift, scale = out.chunk(2, dim=1)
         out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
         out = out * (1 + scale) + shift
         out = weights.final_layer_linear.apply(out.to(torch.float32))
-        _, _, ot, oh, ow = shape
+        _, _, ot, oh, ow = self.scheduler.latents.shape
         patch_size = [1, 2, 2]
         tt, th, tw = (
             ot // patch_size[0],
......
@@ -5,11 +5,25 @@ from lightx2v.attentions import attention

 class HunyuanPreInfer:
-    def __init__(self):
+    def __init__(self, config):
         self.heads_num = 24
+        self.config = config
+
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler

-    def infer(self, weights, x, t, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance, img_latents=None):
-        if img_latents is not None:
+    def infer(self, weights, inputs):
+        x = self.scheduler.latents
+        t = self.scheduler.timesteps[self.scheduler.step_index]
+        freqs_cos = self.scheduler.freqs_cos
+        freqs_sin = self.scheduler.freqs_sin
+        guidance = self.scheduler.guidance
+        text_states = inputs["text_encoder_output"]["text_encoder_1_text_states"]
+        text_mask = inputs["text_encoder_output"]["text_encoder_1_attention_mask"]
+        text_states_2 = inputs["text_encoder_output"]["text_encoder_2_text_states"]
+        if self.config["task"] == "i2v":
             token_replace_t = torch.zeros_like(t)
             token_replace_vec = self.infer_time_in(weights, token_replace_t)
             th = x.shape[-2] // 2
@@ -22,7 +36,7 @@ class HunyuanPreInfer:
         infer_vector_out = self.infer_vector_in(weights, text_states_2)
         vec = time_out + infer_vector_out
-        if img_latents is not None:
+        if self.config["task"] == "i2v":
             token_replace_vec = token_replace_vec + infer_vector_out
         guidance_out = self.infer_guidance_in(weights, guidance)
@@ -43,7 +57,7 @@ class HunyuanPreInfer:
             cu_seqlens_qkv[2 * i + 2] = s2
         max_seqlen_qkv = img_seq_len + txt_seq_len
-        if img_latents is not None:
+        if self.config["task"] == "i2v":
             return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
         return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)
......
@@ -69,12 +69,14 @@ class HunyuanModel:
         self.transformer_weights.load_weights(weight_dict)

     def _init_infer(self):
-        self.pre_infer = self.pre_infer_class()
-        self.post_infer = self.post_infer_class()
+        self.pre_infer = self.pre_infer_class(self.config)
+        self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
+        self.pre_infer.set_scheduler(scheduler)
+        self.post_infer.set_scheduler(scheduler)
         self.transformer_infer.set_scheduler(scheduler)

     def to_cpu(self):
@@ -88,28 +90,18 @@ class HunyuanModel:
         self.transformer_weights.to_cuda()

     @torch.no_grad()
-    def infer(self, text_encoder_output, image_encoder_output, args):
+    def infer(self, inputs):
         if self.config["cpu_offload"]:
             self.pre_weight.to_cuda()
             self.post_weight.to_cuda()
-        pre_infer_out = self.pre_infer.infer(
-            self.pre_weight,
-            self.scheduler.latents,
-            self.scheduler.timesteps[self.scheduler.step_index],
-            text_encoder_output["text_encoder_1_text_states"],
-            text_encoder_output["text_encoder_1_attention_mask"],
-            text_encoder_output["text_encoder_2_text_states"],
-            self.scheduler.freqs_cos,
-            self.scheduler.freqs_sin,
-            self.scheduler.guidance,
-            img_latents=image_encoder_output["img_latents"] if "img_latents" in image_encoder_output else None,
-        )
-        img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
-        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
+        inputs = self.pre_infer.infer(self.pre_weight, inputs)
+        inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
+        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)
         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()
         if self.config["feature_caching"] == "Tea":
             self.scheduler.cnt += 1
             if self.scheduler.cnt == self.scheduler.num_steps:
......
@@ -8,6 +8,9 @@ class WanPostInfer:
         self.out_dim = config["out_dim"]
         self.patch_size = (1, 2, 2)
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler
+
     def infer(self, weights, x, e, grid_sizes):
         e = (weights.head_modulation + e.unsqueeze(1)).chunk(2, dim=1)
         norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
......
@@ -22,7 +22,20 @@ class WanPreInfer:
         self.dim = config["dim"]
         self.text_len = config["text_len"]
-    def infer(self, weights, x, t, context, seq_len, clip_fea=None, y=None):
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler
+
+    def infer(self, weights, inputs, positive):
+        x = [self.scheduler.latents]
+        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
+        if positive:
+            context = inputs["text_encoder_output"]["context"]
+        else:
+            context = inputs["text_encoder_output"]["context_null"]
+        seq_len = self.scheduler.seq_len
+        clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
+        y = [inputs["image_encoder_output"]["vae_encode_out"]]
         if self.task == "i2v":
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
......
@@ -95,6 +95,8 @@ class WanModel:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
+        self.pre_infer.set_scheduler(scheduler)
+        self.post_infer.set_scheduler(scheduler)
         self.transformer_infer.set_scheduler(scheduler)

     def to_cpu(self):
@@ -108,24 +110,13 @@ class WanModel:
         self.transformer_weights.to_cuda()

     @torch.no_grad()
-    def infer(self, text_encoders_output, image_encoder_output, args):
-        timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
+    def infer(self, inputs):
         if self.config["cpu_offload"]:
             self.pre_weight.to_cuda()
             self.post_weight.to_cuda()
-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
-            self.pre_weight,
-            [self.scheduler.latents],
-            timestep,
-            text_encoders_output["context"],
-            self.scheduler.seq_len,
-            image_encoder_output["clip_encoder_out"],
-            [image_encoder_output["vae_encode_out"]],
-        )
+        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
         if self.config["feature_caching"] == "Tea":
@@ -133,16 +124,7 @@ class WanModel:
             if self.scheduler.cnt >= self.scheduler.num_steps:
                 self.scheduler.cnt = 0
-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
-            self.pre_weight,
-            [self.scheduler.latents],
-            timestep,
-            text_encoders_output["context_null"],
-            self.scheduler.seq_len,
-            image_encoder_output["clip_encoder_out"],
-            [image_encoder_output["vae_encode_out"]],
-        )
+        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
@@ -151,7 +133,7 @@ class WanModel:
             if self.scheduler.cnt >= self.scheduler.num_steps:
                 self.scheduler.cnt = 0
-        self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+        self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
......
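Note (not part of the diff): the last changed line above is the standard classifier-free guidance blend; a toy scalar illustration, with sample_guide_scale as an assumed value.

# Classifier-free guidance: push the conditional prediction away from the unconditional one.
noise_pred_uncond, noise_pred_cond = 0.10, 0.30   # toy scalars standing in for tensors
sample_guide_scale = 5.0                          # assumed guidance scale, not from the diff
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred)  # 1.10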
@@ -23,8 +23,8 @@ class TextEncoderHFClipModel:
         self.model = self.model.to("cuda")

     @torch.no_grad()
-    def infer(self, text, args):
-        if args.cpu_offload:
+    def infer(self, text, config):
+        if config.cpu_offload:
             self.to_cuda()
         tokens = self.tokenizer(
             text,
@@ -44,7 +44,7 @@ class TextEncoderHFClipModel:
         )
         last_hidden_state = outputs["pooler_output"]
-        if args.cpu_offload:
+        if config.cpu_offload:
             self.to_cpu()
         return last_hidden_state, tokens["attention_mask"]
......
@@ -34,8 +34,8 @@ class TextEncoderHFLlamaModel:
         self.model = self.model.to("cuda")

     @torch.no_grad()
-    def infer(self, text, args):
-        if args.cpu_offload:
+    def infer(self, text, config):
+        if config.cpu_offload:
             self.to_cuda()
         text = self.prompt_template.format(text)
         tokens = self.tokenizer(
@@ -57,7 +57,7 @@ class TextEncoderHFLlamaModel:
         last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start :]
         attention_mask = tokens["attention_mask"][:, self.crop_start :]
-        if args.cpu_offload:
+        if config.cpu_offload:
             self.to_cpu()
         return last_hidden_state, attention_mask
......
@@ -98,9 +98,9 @@ class TextEncoderHFLlavaModel:
         self.model = self.model.to("cuda")

     @torch.no_grad()
-    def infer(self, text, img, args):
-        # if args.cpu_offload:
-        #     self.to_cuda()
+    def infer(self, text, img, config):
+        if config.cpu_offload:
+            self.to_cuda()
         text = self.prompt_template.format(text)
         print(f"text: {text}")
         tokens = self.tokenizer(
@@ -148,8 +148,8 @@ class TextEncoderHFLlavaModel:
         last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
         attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
-        # if args.cpu_offload:
-        #     self.to_cpu()
+        if config.cpu_offload:
+            self.to_cpu()
         return last_hidden_state, attention_mask
......
@@ -492,8 +492,8 @@ class T5EncoderModel:
     def to_cuda(self):
         self.model = self.model.to("cuda")

-    def infer(self, texts, args):
-        if args.cpu_offload:
+    def infer(self, texts, config):
+        if config.cpu_offload:
             self.to_cuda()
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
@@ -502,7 +502,7 @@ class T5EncoderModel:
         seq_lens = mask.gt(0).sum(dim=1).long()
         context = self.model(ids, mask)
-        if args.cpu_offload:
+        if config.cpu_offload:
             self.to_cpu()
         return [u[:v] for u, v in zip(context, seq_lens)]
......
@@ -4,15 +4,15 @@ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D, DiagonalGaussianDis

 class VideoEncoderKLCausal3DModel:
-    def __init__(self, model_path, dtype, device, args):
+    def __init__(self, model_path, dtype, device, config):
         self.model_path = model_path
         self.dtype = dtype
         self.device = device
-        self.args = args
+        self.config = config
         self.load()

     def load(self):
-        if self.args.task == "t2v":
+        if self.config.task == "t2v":
             self.vae_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/vae")
         else:
             self.vae_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/vae")
@@ -30,8 +30,8 @@ class VideoEncoderKLCausal3DModel:
     def to_cuda(self):
         self.model = self.model.to("cuda")

-    def decode(self, latents, generator, args):
-        if args.cpu_offload:
+    def decode(self, latents, generator, config):
+        if config.cpu_offload:
             self.to_cuda()
         latents = latents / self.model.config.scaling_factor
         latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
@@ -39,11 +39,11 @@ class VideoEncoderKLCausal3DModel:
         image = self.model.decode(latents, return_dict=False, generator=generator)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().float()
-        if args.cpu_offload:
+        if config.cpu_offload:
             self.to_cpu()
         return image

-    def encode(self, x, args):
+    def encode(self, x, config):
         h = self.model.encoder(x)
         moments = self.model.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
......
@@ -786,8 +786,8 @@ class WanVAE:
         return images

-    def decode(self, zs, generator, args):
-        if args.cpu_offload:
+    def decode(self, zs, generator, config):
+        if config.cpu_offload:
             self.to_cuda()
         if self.parallel:
@@ -806,7 +806,7 @@ class WanVAE:
         else:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
-        if args.cpu_offload:
+        if config.cpu_offload:
             images = images.cpu().float()
             self.to_cpu()
......
+import json
+from easydict import EasyDict
+
+
+def set_config(args):
+    config = {k: v for k, v in vars(args).items()}
+    config = EasyDict(config)
+
+    if args.mm_config:
+        config.mm_config = json.loads(args.mm_config)
+    else:
+        config.mm_config = None
+
+    if args.config_path is not None:
+        with open(args.config_path, "r") as f:
+            model_config = json.load(f)
+        config.update(model_config)
+
+    return config
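Note (not part of the diff): a minimal usage sketch of the new helper; the argparse flags below are illustrative stand-ins, only mm_config and config_path are known from this diff to be consumed by set_config.

# Hypothetical driver showing how an argparse namespace becomes the unified EasyDict config.
import argparse
from lightx2v.utils.set_config import set_config

parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", default="wan2.1")  # illustrative flag
parser.add_argument("--task", default="t2v")          # illustrative flag
parser.add_argument("--mm_config", default=None)      # JSON string, parsed by set_config
parser.add_argument("--config_path", default=None)    # optional JSON file merged into config
args = parser.parse_args()

config = set_config(args)
print(config.model_cls, config["task"])  # EasyDict allows both attribute and key access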