"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "be937d4a201a96a537d645992ebfbe21f70cc493"
Commit 00962c67 authored by gushiqiao

Fix audio model compile and offload bugs

parent 4389450a
@@ -113,7 +113,7 @@ class PerceiverAttentionCA(nn.Module):
         shift_scale_gate = torch.zeros((1, 3, inner_dim))
         shift_scale_gate[:, 2] = 1
         self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

     def forward(self, x, latents, t_emb, q_lens, k_lens):
         """x shape (batchsize, latent_frame, audio_tokens_per_latent,
         model_dim) latents (batchsize, length, model_dim)"""
...
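A brief, hedged note on the shift_scale_gate buffer above: initializing shift and scale to 0 and the gate to 1 is the usual adaLN-style default, so the modulation starts out close to the identity. The forward body is not shown in this hunk, so the sketch below only illustrates the common pattern such a buffer supports; t_emb_offsets and the splitting scheme are assumptions, not code from the commit.

# Hedged sketch of adaLN-style modulation with a (1, 3, dim) shift/scale/gate
# buffer initialized to shift=0, scale=0, gate=1 (near-identity at start).
# How PerceiverAttentionCA actually combines it with t_emb is not shown above.
import torch

dim = 8
shift_scale_gate = torch.zeros((1, 3, dim))
shift_scale_gate[:, 2] = 1                                  # gate starts at 1

t_emb_offsets = torch.zeros((1, 3, dim))                    # assumed time-conditioned offsets
shift, scale, gate = (shift_scale_gate + t_emb_offsets).chunk(3, dim=1)

x = torch.randn(1, 4, dim)
attn_out = torch.randn(1, 4, dim)                           # stand-in for an attention output
modulated = x * (1 + scale) + shift                         # pre-attention modulation
updated = x + gate * attn_out                               # gated residual update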
@@ -83,7 +83,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k

+    def compute_freqs(self, q, grid_sizes, freqs):
+        if "audio" in self.config.get("model_cls", ""):
+            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+        else:
+            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+        return freqs_i
+
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
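For readers unfamiliar with the cu_seqlens_q / cu_seqlens_k context lines above: they build cumulative sequence-length offsets of the kind consumed by variable-length (packed) attention kernels. A minimal standalone sketch, using only the same torch calls as the diff and illustrative lengths of my own:

# Per-sample lengths become cumulative offsets that a varlen attention kernel
# uses to locate each sequence inside a packed batch.
import torch

q_lens = torch.tensor([3, 5], dtype=torch.int32)   # two sequences, lengths 3 and 5
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
print(cu_seqlens_q)                                # tensor([0, 3, 8], dtype=torch.int32)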
@@ -108,6 +115,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
@@ -136,7 +145,9 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()

             if block_idx == self.blocks_num - 1:
@@ -144,6 +155,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 self.weights_stream_mgr._async_prefetch_block(weights.blocks)

         if self.clean_cuda_cache:
             del grid_sizes, embed, embed0, seq_lens, freqs, context
             torch.cuda.empty_cache()
@@ -178,6 +190,8 @@ class WanTransformerInfer(BaseTransformerInfer):
             elif cur_phase_idx == 3:
                 y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                 x = self.post_process(x, y, c_gate_msa)
+                if audio_dit_blocks:
+                    x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)

             is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
             if not is_last_phase:
@@ -238,6 +252,8 @@ class WanTransformerInfer(BaseTransformerInfer):
             elif cur_phase_idx == 3:
                 y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                 x = self.post_process(x, y, c_gate_msa)
+                if audio_dit_blocks:
+                    x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)

             if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                 next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
@@ -274,6 +290,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
         return freqs_cis

+    @torch._dynamo.disable
+    def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
+        for ipa_out in audio_dit_blocks:
+            if block_idx in ipa_out:
+                cur_modify = ipa_out[block_idx]
+                x = cur_modify["modify_func"](x,
+                                              grid_sizes,
+                                              **cur_modify["kwargs"])
+        return x
+
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
@@ -286,12 +312,9 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
-            if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
-                for ipa_out in audio_dit_blocks:
-                    if block_idx in ipa_out:
-                        cur_modify = ipa_out[block_idx]
-                        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
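Some context on the new _apply_audio_dit hook: the code implies that audio_dit_blocks is an iterable of dicts keyed by transformer block index, each value carrying a "modify_func" callable and its "kwargs". Because the hook is decorated with @torch._dynamo.disable, it runs eagerly even when the surrounding infer path is wrapped in @torch.compile, which is one way to keep a dynamic, data-dependent hook out of the compiled graph. The sketch below is a hedged illustration of that contract, not code from the commit; inject_audio_features is a hypothetical stand-in for the real cross-attention callable.

# Hedged sketch of the audio_dit_blocks structure consumed by _apply_audio_dit.
# Only the dict layout is inferred from the diff above; the callable is a dummy.
def inject_audio_features(x, grid_sizes, audio_emb=None, scale=1.0):
    # A real modify_func would cross-attend x to audio features; here we just pass x through.
    return x

audio_dit_blocks = [
    {
        0: {"modify_func": inject_audio_features, "kwargs": {"audio_emb": None, "scale": 1.0}},
        4: {"modify_func": inject_audio_features, "kwargs": {"audio_emb": None, "scale": 0.5}},
    }
]

# Inside each per-block loop the transformer then calls, for the current block_idx:
#     x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
# which invokes the matching modify_func only for blocks listed in the dict.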
@@ -327,13 +350,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def compute_freqs(self, q, grid_sizes, freqs):
-        if "audio" in self.config.get("model_cls", ""):
-            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
-        else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i
-
     def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * weights.smooth_norm1_weight.tensor
...
@@ -302,7 +302,7 @@ class VideoGenerator:
         return mask.transpose(0, 1)

     @torch.no_grad()
-    def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
+    def generate_segment(self, inputs, audio_features, prev_video=None, prev_frame_length=5, segment_idx=0, total_steps=None):
         """Generate video segment"""
         # Update inputs with audio features
         inputs["audio_encoder_output"] = audio_features
@@ -352,7 +352,8 @@ class VideoGenerator:
         inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}

         # Run inference loop
-        total_steps = self.model.scheduler.infer_steps
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
         for step_index in range(total_steps):
             logger.info(f"==> Segment {segment_idx}, Step {step_index}/{total_steps}")
@@ -686,6 +687,62 @@ class WanAudioRunner(WanRunner): # type:ignore
         ret["target_shape"] = self.config.target_shape
         return ret

+    def run_step(self):
+        """Optimized pipeline with modular components"""
+        self.initialize()
+        assert self._audio_processor is not None
+        assert self._audio_preprocess is not None
+        self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
+
+        with memory_efficient_inference():
+            if self.config["use_prompt_enhancer"]:
+                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.inputs = self.prepare_inputs()
+
+        # Re-initialize scheduler after image encoding sets correct dimensions
+        self.init_scheduler()
+        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+
+        # Re-create video generator with updated model/scheduler
+        self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
+
+        # Process audio
+        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
+        video_duration = self.config.get("video_duration", 5)
+        target_fps = self.config.get("target_fps", 16)
+        max_num_frames = self.config.get("target_video_length", 81)
+        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
+        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
+
+        # Segment audio
+        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
+        self._video_generator.total_segments = len(audio_segments)
+
+        # Generate video segments
+        prev_video = None
+        torch.manual_seed(self.config.seed)
+
+        # Process audio features
+        audio_features = self._audio_preprocess(audio_segments[0].audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
+
+        # Generate video segment
+        with memory_efficient_inference():
+            self._video_generator.generate_segment(
+                self.inputs.copy(),  # Copy to avoid modifying original
+                audio_features,
+                prev_video=prev_video,
+                prev_frame_length=5,
+                segment_idx=0,
+                total_steps=1
+            )
+
+        # Final cleanup
+        self.end_run()
+

 @RUNNER_REGISTER("wan2.2_moe_audio")
 class Wan22MoeAudioRunner(WanAudioRunner):
...
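A worked example of the frame-budget arithmetic in run_step above, with illustrative numbers of my own (the 16 kHz sample rate and 3-second clip are assumptions, not values from the commit): with the defaults video_duration=5 and target_fps=16, the requested budget is 80 frames, but 3 seconds of audio only covers 48 frames, so expected_frames is capped by the audio length.

# Hedged worked example of the expected_frames computation in run_step.
audio_samples = int(3.0 * 16000)                             # 48000 samples, i.e. 3 s at 16 kHz
target_fps = 16
video_duration = 5                                           # seconds requested
audio_len = int(audio_samples / 16000 * target_fps)          # 48 frames covered by the audio
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
print(expected_frames)                                       # 48: capped by audio, not the 80-frame request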