Commit 6a9e8f6a authored by gushiqiao's avatar gushiqiao Committed by Yang Yong(雍洋)
Browse files

Fix bugs

parent 91c5dd15
......@@ -41,7 +41,7 @@ def load_models(args, model_config):
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
text_encoders = [text_encoder_1, text_encoder_2]
model = HunyuanModel(args.model_path, model_config, device=init_device)
model = HunyuanModel(args.model_path, model_config, init_device)
vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device)
elif args.model_cls == "wan2.1":
......@@ -54,7 +54,7 @@ def load_models(args, model_config):
shard_fn=None,
)
text_encoders = [text_encoder]
model = WanModel(args.model_path, model_config, device=init_device)
model = WanModel(args.model_path, model_config, init_device)
vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
if args.task == "i2v":
image_encoder = CLIPModel(
......
......@@ -71,13 +71,11 @@ class WanTransformerAttentionBlock:
self.cross_attn_norm_k,
self.ffn_0,
self.ffn_2,
# self.modulation,
]
if self.task == "i2v":
self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias")
self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias")
# self.cross_attn_norm_k_img_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k_img.weight']
self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight")
self.weight_list.append(self.cross_attn_k_img)
self.weight_list.append(self.cross_attn_v_img)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment