Commit 783b3a72 authored by gushiqiao, committed by GitHub

Fix audio model compile and offload bugs

Dev gsq
parents 9067043e 92f067f1
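The compile fix hinges on keeping the data-dependent audio hook out of the traced graph: infer stays wrapped in torch.compile, while the new _apply_audio_dit helper is decorated with torch._dynamo.disable so the offload paths can call it eagerly. A minimal sketch of that pattern, assuming PyTorch 2.x (apply_audio_hook, block_forward and run_blocks are illustrative names, not code from this commit):

import torch

# Sketch only: the compiled per-block math stays graph-friendly, while the
# dynamic audio hook is excluded from tracing and runs eagerly.
@torch._dynamo.disable
def apply_audio_hook(x, block_idx, grid_sizes, audio_dit_blocks):
    # audio_dit_blocks: iterable of {block_idx: {"modify_func": fn, "kwargs": {...}}}
    for ipa_out in audio_dit_blocks or []:
        if block_idx in ipa_out:
            entry = ipa_out[block_idx]
            x = entry["modify_func"](x, grid_sizes, **entry["kwargs"])
    return x

@torch.compile
def block_forward(x):
    # Stand-in for the per-block transformer computation that benefits from compilation.
    return x * 2.0

def run_blocks(x, grid_sizes, audio_dit_blocks, blocks_num=4):
    for block_idx in range(blocks_num):
        x = block_forward(x)
        if audio_dit_blocks:
            # Runs outside the compiled graph, so per-block dict lookups and
            # arbitrary callables do not trigger graph breaks or recompiles.
            x = apply_audio_hook(x, block_idx, grid_sizes, audio_dit_blocks)
    return x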
@@ -85,6 +85,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k

+    def compute_freqs(self, q, grid_sizes, freqs):
+        if "audio" in self.config.get("model_cls", ""):
+            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+        else:
+            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+        return freqs_i
+
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
@@ -109,6 +116,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
@@ -137,6 +146,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
             self.weights_stream_mgr.swap_weights()
@@ -179,6 +190,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                     x = self.post_process(x, y, c_gate_msa)
+                    if audio_dit_blocks:
+                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
                 is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                 if not is_last_phase:
@@ -239,6 +252,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 elif cur_phase_idx == 3:
                     y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                     x = self.post_process(x, y, c_gate_msa)
+                    if audio_dit_blocks:
+                        x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
                 if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                     next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
@@ -275,6 +290,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
         return freqs_cis

+    @torch._dynamo.disable
+    def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
+        for ipa_out in audio_dit_blocks:
+            if block_idx in ipa_out:
+                cur_modify = ipa_out[block_idx]
+                x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+        return x
+
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
@@ -287,12 +310,9 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )

-            if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
-                for ipa_out in audio_dit_blocks:
-                    if block_idx in ipa_out:
-                        cur_modify = ipa_out[block_idx]
-                        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
+            if audio_dit_blocks:
+                x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
@@ -328,13 +348,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def compute_freqs(self, q, grid_sizes, freqs):
-        if "audio" in self.config.get("model_cls", ""):
-            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
-        else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i
-
     def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * weights.smooth_norm1_weight.tensor
...
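For context on the new helper above: _apply_audio_dit expects audio_dit_blocks to be an iterable of dicts keyed by transformer block index, each entry carrying a modify_func and its kwargs. A hypothetical construction (inject_audio and its arguments are illustrative, not part of this commit):

def inject_audio(x, grid_sizes, audio_feat=None, scale=1.0):
    # Illustrative stand-in for the audio adapter's modify_func; the real one
    # blends audio-conditioned features into the hidden states x.
    return x

audio_dit_blocks = [
    {0: {"modify_func": inject_audio, "kwargs": {"audio_feat": None, "scale": 1.0}}},
    {8: {"modify_func": inject_audio, "kwargs": {"audio_feat": None, "scale": 0.5}}},
]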
@@ -3,7 +3,7 @@ import os
 import subprocess
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import numpy as np
 import torch
@@ -302,7 +302,7 @@ class VideoGenerator:
         return mask.transpose(0, 1)

     @torch.no_grad()
-    def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
+    def generate_segment(self, inputs, audio_features, prev_video=None, prev_frame_length=5, segment_idx=0, total_steps=None):
         """Generate video segment"""
         # Update inputs with audio features
         inputs["audio_encoder_output"] = audio_features
@@ -352,7 +352,8 @@ class VideoGenerator:
         inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}

         # Run inference loop
-        total_steps = self.model.scheduler.infer_steps
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps

         for step_index in range(total_steps):
             logger.info(f"==> Segment {segment_idx}, Step {step_index}/{total_steps}")
@@ -694,6 +695,62 @@ class WanAudioRunner(WanRunner):  # type:ignore
         ret["target_shape"] = self.config.target_shape
         return ret

+    def run_step(self):
+        """Optimized pipeline with modular components"""
+        self.initialize()
+
+        assert self._audio_processor is not None
+        assert self._audio_preprocess is not None
+
+        self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
+
+        with memory_efficient_inference():
+            if self.config["use_prompt_enhancer"]:
+                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.inputs = self.prepare_inputs()
+
+            # Re-initialize scheduler after image encoding sets correct dimensions
+            self.init_scheduler()
+            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+
+            # Re-create video generator with updated model/scheduler
+            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
+
+        # Process audio
+        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
+        video_duration = self.config.get("video_duration", 5)
+        target_fps = self.config.get("target_fps", 16)
+        max_num_frames = self.config.get("target_video_length", 81)
+        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
+        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
+
+        # Segment audio
+        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
+        self._video_generator.total_segments = len(audio_segments)
+
+        # Generate video segments
+        prev_video = None
+        torch.manual_seed(self.config.seed)
+
+        # Process audio features
+        audio_features = self._audio_preprocess(audio_segments[0].audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
+
+        # Generate video segment
+        with memory_efficient_inference():
+            self._video_generator.generate_segment(
+                self.inputs.copy(),  # Copy to avoid modifying original
+                audio_features,
+                prev_video=prev_video,
+                prev_frame_length=5,
+                segment_idx=0,
+                total_steps=1,
+            )
+
+        # Final cleanup
+        self.end_run()
+

 @RUNNER_REGISTER("wan2.2_moe_audio")
 class Wan22MoeAudioRunner(WanAudioRunner):
...
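As a sanity check on the frame-count arithmetic in run_step, assuming 16 kHz input audio and the defaults read from the config above (16 fps, 81-frame segments), a 5-second clip works out as follows (the numbers are illustrative):

audio_sr = 16000          # assumed sample rate of the loaded audio
target_fps = 16
video_duration = 5
max_num_frames = 81

audio_samples = video_duration * audio_sr                                    # 80_000 samples
audio_len = int(audio_samples / audio_sr * target_fps)                       # 80 frames covered by audio
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)   # 80
# segment_audio then splits those 80 frames into chunks of at most 81, so this
# clip yields a single segment, which run_step denoises with total_steps=1.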