Commit 39a4849a authored by gaclove

refactor: streamline run_pipeline method in WanAudioRunner for improved clarity and efficiency, including enhanced error handling and memory management
parent 1a881d63
@@ -451,118 +451,121 @@ class WanAudioRunner(WanRunner):  # type:ignore
    def run_pipeline(self, save_video=True):
        """Optimized pipeline with modular components"""
        try:
            # Ensure models are initialized
            self.initialize()
            assert self._audio_processor is not None
            assert self._audio_preprocess is not None
            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)

            # Prepare inputs
            with memory_efficient_inference():
                if self.config["use_prompt_enhancer"]:
                    self.config["prompt_enhanced"] = self.post_prompt_enhancer()
                self.inputs = self.prepare_inputs()
                # Re-initialize scheduler after image encoding sets correct dimensions
                self.init_scheduler()
                self.model.scheduler.prepare(self.inputs["image_encoder_output"])
                # Re-create video generator with updated model/scheduler
                self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)

            # Process audio
            audio_array = self._audio_processor.load_audio(self.config["audio_path"])
            video_duration = self.config.get("video_duration", 5)
            target_fps = self.config.get("target_fps", 16)
            max_num_frames = self.config.get("target_video_length", 81)
            audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
            expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)

            # Segment audio
            audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
            self._video_generator.total_segments = len(audio_segments)

            # Generate video segments
            gen_video_list = []
            cut_audio_list = []
            prev_video = None

            for idx, segment in enumerate(audio_segments):
                # Update seed for each segment
                self.config.seed = self.config.seed + idx
                torch.manual_seed(self.config.seed)
                logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")

                # Process audio features
                audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)

                # Generate video segment
                with memory_efficient_inference():
                    gen_video = self._video_generator.generate_segment(
                        self.inputs.copy(),  # Copy to avoid modifying original
                        audio_features,
                        prev_video=prev_video,
                        prev_frame_length=5,
                        segment_idx=idx,
                    )

                # Extract relevant frames
                start_frame = 0 if idx == 0 else 5
                start_audio_frame = 0 if idx == 0 else int(6 * self._audio_processor.audio_sr / target_fps)
                if segment.is_last and segment.useful_length:
                    end_frame = segment.end_frame - segment.start_frame
                    gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
                elif segment.useful_length and expected_frames < max_num_frames:
                    gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
                else:
                    gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame:])

                # Update prev_video for next iteration
                prev_video = gen_video

                # Clean up GPU memory after each segment
                del gen_video
                torch.cuda.empty_cache()

            # Merge results
            with memory_efficient_inference():
                gen_lvideo = torch.cat(gen_video_list, dim=2).float()
                merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
                comfyui_images = vae_to_comfyui_image(gen_lvideo)

            # Apply frame interpolation if configured
            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
                interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
                logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
                comfyui_images = self.vfi_model.interpolate_frames(
                    comfyui_images,
                    source_fps=target_fps,
                    target_fps=interpolation_target_fps,
                )
                target_fps = interpolation_target_fps

            # Convert audio to ComfyUI format
            audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
            comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}

            # Save video if requested
            if save_video and self.config.get("save_video_path", None):
                self._save_video_with_audio(comfyui_images, merge_audio, target_fps)

            # Final cleanup
            self.end_run()
            return comfyui_images, comfyui_audio
        finally:
            self._video_generator = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
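
Note on the memory-management pattern: every VRAM-heavy stage above runs inside memory_efficient_inference(), and the finally block releases the generator and empties the CUDA cache even when a segment fails. The helper's implementation is not part of this diff; a minimal sketch, assuming it only needs to disable autograd bookkeeping and flush cached GPU memory on exit, could look like:

import gc
from contextlib import contextmanager

import torch

@contextmanager
def memory_efficient_inference():
    # Hypothetical sketch; the real helper is not shown in this diff.
    # Run the wrapped block without autograd state, then release cached
    # GPU memory on exit (mirrors the cleanup in the finally block above).
    try:
        with torch.inference_mode():
            yield
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

Segment stitching relies on a fixed 5-frame overlap (prev_frame_length=5): every segment after the first drops its first 5 generated video frames and the first int(6 * audio_sr / target_fps) audio samples (six frames' worth of audio, e.g. 6000 samples at an assumed 16 kHz rate and 16 fps), so the torch.cat in the merge step does not duplicate the conditioning overlap.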
    def _save_video_with_audio(self, images, audio_array, fps):
        """Save video with audio"""
......
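
For reference, a hypothetical driver for the refactored pipeline. The config keys mirror the ones run_pipeline reads above; the constructor signature and the dict-plus-attribute config object (config["audio_path"] vs. config.seed) are assumptions, since neither is shown in this diff:

# Hypothetical usage sketch; constructor and config type are assumed.
config = EasyConfig({                # assumed hybrid dict/attribute config
    "audio_path": "speech.wav",
    "video_duration": 5,             # seconds of video to generate
    "target_fps": 16,
    "target_video_length": 81,       # max frames per generated segment
    "use_prompt_enhancer": False,
    "seed": 42,
    "save_video_path": "out.mp4",
})
runner = WanAudioRunner(config)
images, audio = runner.run_pipeline(save_video=True)
# returns ComfyUI images plus {"waveform": ..., "sample_rate": ...}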