Commit edeae441 authored by gaclove


refactor: update WanTransformerInfer and VideoGenerator to improve audio handling and progress tracking
parent 039456f2
@@ -343,12 +343,12 @@ class WanTransformerInfer(BaseTransformerInfer):
         v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
         if not self.parallel_attention:
-            if self.config.get("audio_sr", False):
+            if "audio" in self.config.get("model_cls", ""):
                 freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
             else:
                 freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         else:
-            if self.config.get("audio_sr", False):
+            if "audio" in self.config.get("model_cls", ""):
                 freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
             else:
                 freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
...
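A likely motivation, sketched below with hypothetical config values: elsewhere in this commit audio_sr is read as a sample rate (default 16000), so treating it as an is-audio flag conflates two settings, while model_cls names the model variant directly.

    # Hypothetical config values for illustration only.
    config = {"model_cls": "wan_audio", "audio_sr": 16000}

    old_is_audio = bool(config.get("audio_sr", False))     # True just because a sample rate is set
    new_is_audio = "audio" in config.get("model_cls", "")  # True because the model class says so
    print(old_is_audio, new_is_audio)  # True True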
@@ -120,7 +120,7 @@ class DefaultRunner(BaseRunner):
             self.model.scheduler.step_post()

             if self.progress_callback:
-                self.progress_callback(step_index + 1, total_steps)
+                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

         return self.model.scheduler.latents, self.model.scheduler.generator
...
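After this change the callback receives a percentage rather than a raw step count. A minimal consumer sketch (the callback body is illustrative):

    def progress_callback(current: float, total: float) -> None:
        # current is a percentage in [0, 100]; total is always 100
        print(f"denoising: {current:.1f}/{total}")

    total_steps, step_index = 40, 4
    progress_callback(((step_index + 1) / total_steps) * 100, 100)  # denoising: 12.5/100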
@@ -255,12 +255,14 @@ class AudioProcessor:
 class VideoGenerator:
     """Handles video generation for each segment"""

-    def __init__(self, model, vae_encoder, vae_decoder, config):
+    def __init__(self, model, vae_encoder, vae_decoder, config, progress_callback=None):
         self.model = model
         self.vae_encoder = vae_encoder
         self.vae_decoder = vae_decoder
         self.config = config
         self.frame_preprocessor = FramePreprocessor()
+        self.progress_callback = progress_callback
+        self.total_segments = 1

     def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
         """Prepare previous latents for conditioning"""
@@ -352,8 +354,9 @@ class VideoGenerator:
             inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}

         # Run inference loop
-        for step_index in range(self.model.scheduler.infer_steps):
-            logger.info(f"==> Segment {segment_idx}, Step {step_index}/{self.model.scheduler.infer_steps}")
+        total_steps = self.model.scheduler.infer_steps
+        for step_index in range(total_steps):
+            logger.info(f"==> Segment {segment_idx}, Step {step_index}/{total_steps}")

             with ProfilingContext4Debug("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)
@@ -364,6 +367,10 @@ class VideoGenerator:
             with ProfilingContext4Debug("step_post"):
                 self.model.scheduler.step_post()

+            if self.progress_callback:
+                segment_progress = (segment_idx * total_steps + step_index + 1) / (self.total_segments * total_steps)
+                self.progress_callback(int(segment_progress * 100), 100)
+
         # Decode latents
         latents = self.model.scheduler.latents
         generator = self.model.scheduler.generator
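A worked example of the segment-aware progress formula added above, assuming 3 segments at 40 inference steps each (values are illustrative):

    total_segments, total_steps = 3, 40

    def overall_progress(segment_idx: int, step_index: int) -> int:
        done = segment_idx * total_steps + step_index + 1
        return int(done / (total_segments * total_steps) * 100)

    assert overall_progress(0, 39) == 33   # first segment finished: about a third
    assert overall_progress(1, 19) == 50   # halfway through the second segment
    assert overall_progress(2, 39) == 100  # last step of the last segment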
@@ -377,7 +384,6 @@ class WanAudioRunner(WanRunner):
 class WanAudioRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        self._is_initialized = False
         self._audio_adapter_pipe = None
         self._audio_processor = None
         self._video_generator = None
@@ -385,25 +391,15 @@ class WanAudioRunner(WanRunner):
     def initialize_once(self):
         """Initialize all models once for multiple runs"""
-        if self._is_initialized:
-            return
-
-        logger.info("Initializing models (one-time setup)...")
-
         # Initialize audio processor
         audio_sr = self.config.get("audio_sr", 16000)
         target_fps = self.config.get("target_fps", 16)
         self._audio_processor = AudioProcessor(audio_sr, target_fps)

-        # Load audio feature extractor
-        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
-
         # Initialize scheduler
         self.init_scheduler()

-        self._is_initialized = True
-        logger.info("Model initialization complete")
-
     def init_scheduler(self):
         """Initialize consistency model scheduler"""
         scheduler = ConsistencyModelScheduler(self.config)
@@ -459,7 +455,7 @@ class WanAudioRunner(WanRunner):
         # Initialize video generator if needed
         if self._video_generator is None:
-            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config)
+            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)

         # Prepare inputs
         with memory_efficient_inference():
@@ -481,6 +477,8 @@ class WanAudioRunner(WanRunner):
             # Segment audio
             audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)

+            self._video_generator.total_segments = len(audio_segments)
+
             # Generate video segments
             gen_video_list = []
             cut_audio_list = []
@@ -605,6 +603,9 @@ class WanAudioRunner(WanRunner):
                 lora_wrapper.apply_lora(lora_name, strength)
                 logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")

+        # XXX: trick
+        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+
         return base_model

     def run_image_encoder(self, config, vae_model):
...
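For reference, the relocated feature-extractor load in isolation; from_pretrained with a subfolder argument is standard transformers usage (the model path below is a placeholder):

    from transformers import AutoFeatureExtractor

    # Loads the audio feature extractor stored under the model repo's
    # "audio_encoder" subfolder, matching the line added above.
    audio_preprocess = AutoFeatureExtractor.from_pretrained(
        "/path/to/model",  # stands in for self.config["model_path"]
        subfolder="audio_encoder",
    )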