Commit 39a4849a authored by gaclove's avatar gaclove
Browse files

refactor: streamline run_pipeline method in WanAudioRunner for improved...

refactor: streamline run_pipeline method in WanAudioRunner for improved clarity and efficiency, including enhanced error handling and memory management
parent 1a881d63
...@@ -451,17 +451,15 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -451,17 +451,15 @@ class WanAudioRunner(WanRunner): # type:ignore
def run_pipeline(self, save_video=True): def run_pipeline(self, save_video=True):
"""Optimized pipeline with modular components""" """Optimized pipeline with modular components"""
# Ensure models are initialized
try:
self.initialize() self.initialize()
assert self._audio_processor is not None assert self._audio_processor is not None
assert self._audio_preprocess is not None assert self._audio_preprocess is not None
# Initialize video generator if needed
if self._video_generator is None:
self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback) self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
# Prepare inputs
with memory_efficient_inference(): with memory_efficient_inference():
if self.config["use_prompt_enhancer"]: if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer() self.config["prompt_enhanced"] = self.post_prompt_enhancer()
...@@ -494,7 +492,6 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -494,7 +492,6 @@ class WanAudioRunner(WanRunner): # type:ignore
prev_video = None prev_video = None
for idx, segment in enumerate(audio_segments): for idx, segment in enumerate(audio_segments):
# Update seed for each segment
self.config.seed = self.config.seed + idx self.config.seed = self.config.seed + idx
torch.manual_seed(self.config.seed) torch.manual_seed(self.config.seed)
logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}") logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")
...@@ -564,6 +561,12 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -564,6 +561,12 @@ class WanAudioRunner(WanRunner): # type:ignore
return comfyui_images, comfyui_audio return comfyui_images, comfyui_audio
finally:
self._video_generator = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _save_video_with_audio(self, images, audio_array, fps): def _save_video_with_audio(self, images, audio_array, fps):
"""Save video with audio""" """Save video with audio"""
import tempfile import tempfile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment