Unverified commit 131c8a46 authored by gushiqiao, committed by GitHub

[Fix] Fix model_io datacls (#340)

parent 682037cd
@@ -293,8 +293,8 @@ class Generator(nn.Module):
         self.dec = Synthesis(motion_dim)
 
     def get_motion(self, img):
-        # motion_feat = self.enc.enc_motion(img)
-        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
+        motion_feat = self.enc.enc_motion(img)
+        # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
         with torch.amp.autocast("cuda", dtype=torch.float32):
             motion = self.dec.direction(motion_feat)
         return motion
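
Context for the hunk above: torch.utils.checkpoint trades compute for memory by recomputing activations during the backward pass, so it buys nothing on a pure inference path. A minimal sketch of the trade-off, assuming a generic encoder module (names are illustrative, not from this repo):

import torch
import torch.utils.checkpoint

def encode(enc: torch.nn.Module, img: torch.Tensor, training: bool) -> torch.Tensor:
    if training:
        # Recompute enc's activations during backward to save memory;
        # PyTorch recommends use_reentrant=False for new code.
        return torch.utils.checkpoint.checkpoint(enc, img, use_reentrant=False)
    # Direct call: no recomputation machinery, fastest for inference.
    return enc(img)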
@@ -123,7 +123,7 @@ class WanAudioModel(WanModel):
 
     @torch.no_grad()
     def _seq_parallel_pre_process(self, pre_infer_out):
         x = pre_infer_out.x
-        person_mask_latens = pre_infer_out.adapter_output["person_mask_latens"]
+        person_mask_latens = pre_infer_out.adapter_args["person_mask_latens"]
         world_size = dist.get_world_size(self.seq_p_group)
         cur_rank = dist.get_rank(self.seq_p_group)
@@ -136,7 +136,7 @@ class WanAudioModel(WanModel):
         pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
         if person_mask_latens is not None:
-            pre_infer_out.adapter_output["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]
+            pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]
         if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
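
The two hunks above shard both the latent tokens and the person mask so that each rank in the sequence-parallel group works on its own slice. A standalone sketch of the pattern (assumes an initialized torch.distributed process group; the dims mirror the diff, where x is split on dim=0 and person_mask_latens on dim=1):

import torch
import torch.distributed as dist

def shard_for_rank(t: torch.Tensor, group=None, dim: int = 0) -> torch.Tensor:
    # Keep only this rank's contiguous chunk along the sharded dimension.
    world_size = dist.get_world_size(group)
    cur_rank = dist.get_rank(group)
    return torch.chunk(t, world_size, dim=dim)[cur_rank]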
@@ -14,8 +14,8 @@ class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
     def infer_post_adapter(self, phase, x, pre_infer_out):
         if phase.is_empty():
             return x
-        T = pre_infer_out.motion_vec.shape[0]
-        x_motion = phase.pre_norm_motion.apply(pre_infer_out.motion_vec)
+        T = pre_infer_out.adapter_args["motion_vec"].shape[0]
+        x_motion = phase.pre_norm_motion.apply(pre_infer_out.adapter_args["motion_vec"])
         x_feat = phase.pre_norm_feat.apply(x)
         kv = phase.linear1_kv.apply(x_motion.view(-1, x_motion.shape[-1]))
         kv = kv.view(T, -1, kv.shape[-1])
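
For reference, the kv reshape above flattens all motion tokens so one linear layer processes every frame in a single matmul, then restores the per-frame layout. A standalone sketch with assumed dimensions:

import torch

T, N, C, C_kv = 4, 16, 64, 128          # frames, tokens per frame, channels (assumed)
linear1_kv = torch.nn.Linear(C, C_kv)
x_motion = torch.randn(T, N, C)

kv = linear1_kv(x_motion.view(-1, C))   # (T*N, C_kv): one matmul for all frames
kv = kv.view(T, -1, kv.shape[-1])       # back to (T, N, C_kv)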
@@ -128,5 +128,5 @@ class WanAudioPreInfer(WanPreInfer):
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
-            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
+            adapter_args={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
         )
@@ -22,8 +22,8 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
     @torch.no_grad()
     def infer_post_adapter(self, phase, x, pre_infer_out):
         grid_sizes = pre_infer_out.grid_sizes.tensor
-        audio_encoder_output = pre_infer_out.adapter_output["audio_encoder_output"]
-        person_mask_latens = pre_infer_out.adapter_output["person_mask_latens"]
+        audio_encoder_output = pre_infer_out.adapter_args["audio_encoder_output"]
+        person_mask_latens = pre_infer_out.adapter_args["person_mask_latens"]
         total_tokens = grid_sizes[0].prod()
         pre_frame_tokens = grid_sizes[0][1:].prod()
         n_tokens = total_tokens - pre_frame_tokens  # token count with the ref image removed
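
The token arithmetic above follows from the latent grid: with grid_sizes[0] = (T, H, W), the sequence holds T*H*W tokens and each frame contributes H*W, so dropping the reference frame leaves (T-1)*H*W. A worked example with illustrative values:

import torch

grid = torch.tensor([5, 30, 52])            # (T, H, W), assumed values
total_tokens = grid.prod()                  # 5 * 30 * 52 = 7800
pre_frame_tokens = grid[1:].prod()          # 30 * 52 = 1560
n_tokens = total_tokens - pre_frame_tokens  # 7800 - 1560 = 6240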
@@ -19,5 +19,4 @@ class WanPreInferModuleOutput:
     seq_lens: torch.Tensor
     freqs: torch.Tensor
     context: torch.Tensor
-    motion_vec: torch.Tensor
-    adapter_output: Dict[str, Any] = field(default_factory=dict)
+    adapter_args: Dict[str, Any] = field(default_factory=dict)
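
This dataclass hunk is the core of the fix: the hard-wired motion_vec field is dropped and adapter-specific tensors travel in one generic dict, so adding an adapter no longer requires a schema change. A trimmed sketch of the resulting pattern (field list shortened; only the dict API comes from the diff):

from dataclasses import dataclass, field
from typing import Any, Dict
import torch

@dataclass
class PreInferOutput:  # trimmed stand-in for WanPreInferModuleOutput
    x: torch.Tensor
    context: torch.Tensor
    # default_factory gives each instance its own dict; a plain {} default
    # would be shared across instances (and is rejected for mutable defaults).
    adapter_args: Dict[str, Any] = field(default_factory=dict)

# Producers fill the dict:  PreInferOutput(x=x, context=ctx, adapter_args={"motion_vec": mv})
# Consumers index by key:   out.adapter_args["motion_vec"]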
@@ -217,7 +217,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
                    self.phase_params["c_gate_msa"],
                )
                if hasattr(cur_phase, "after_proj"):
-                    pre_infer_out.adapter_output["hints"].append(cur_phase.after_proj.apply(x))
+                    pre_infer_out.adapter_args["hints"].append(cur_phase.after_proj.apply(x))
            elif cur_phase_idx == 3:
                x = self.infer_post_adapter(cur_phase, x, pre_infer_out)
        return x
@@ -131,5 +131,5 @@ class WanPreInfer:
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
-            motion_vec=motion_vec,
+            adapter_args={"motion_vec": motion_vec},
         )
@@ -108,7 +108,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
         x = self.post_process(x, y, c_gate_msa, pre_infer_out)
         if hasattr(block.compute_phases[2], "after_proj"):
-            pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x))
+            pre_infer_out.adapter_args["hints"].append(block.compute_phases[2].after_proj.apply(x))
         if self.has_post_adapter:
             x = self.infer_post_adapter(block.compute_phases[3], x, pre_infer_out)
@@ -20,7 +20,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         return c
 
     def infer_vace_blocks(self, vace_blocks, pre_infer_out):
-        pre_infer_out.adapter_output["hints"] = []
+        pre_infer_out.adapter_args["hints"] = []
         self.infer_state = "vace"
         if hasattr(self, "weights_stream_mgr"):
             self.weights_stream_mgr.init(self.vace_blocks_num, self.phases_num, self.offload_ratio)
@@ -33,5 +33,5 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
         if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
             hint_idx = self.vace_blocks_mapping[self.block_idx]
-            x = x + pre_infer_out.adapter_output["hints"][hint_idx] * pre_infer_out.adapter_output.get("context_scale", 1.0)
+            x = x + pre_infer_out.adapter_args["hints"][hint_idx] * pre_infer_out.adapter_args.get("context_scale", 1.0)
         return x
@@ -363,18 +363,15 @@ class WanAnimateRunner(WanRunner):
             self.config,
             self.init_device,
         )
-        motion_encoder, face_encoder = self.load_encoder()
+        motion_encoder, face_encoder = self.load_encoders()
         model.set_animate_encoders(motion_encoder, face_encoder)
         return model
 
-    def load_encoder(self):
-        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE())
-        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE())
+    def load_encoders(self):
+        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
+        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
         motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.")
         face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.")
         motion_encoder.load_state_dict(motion_weight_dict)
         face_encoder.load_state_dict(face_weight_dict)
-        if not self.config["cpu_offload"]:
-            motion_encoder = motion_encoder.cuda()
-            face_encoder = face_encoder.cuda()
         return motion_encoder, face_encoder
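
load_weights and remove_substrings_from_keys are repo helpers; a hypothetical stand-in for the prefix stripping, to show why load_state_dict needs it (not the repo's implementation):

def remove_substrings_from_keys(state_dict: dict, substring: str) -> dict:
    # Checkpoint keys look like "motion_encoder.enc.weight", but the standalone
    # sub-module expects "enc.weight", so strip the owning module's prefix.
    return {k.replace(substring, ""): v for k, v in state_dict.items()}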