Commit 5bd9bdbd authored by gaclove
Browse files

refactor: update DefaultRunner to return audio as None and modify...

refactor: update DefaultRunner to return audio as None and modify WanAudioRunner to handle audio saving and frame interpolation
parent 669ed391
@@ -240,4 +240,5 @@ class DefaultRunner(BaseRunner):
torch.cuda.empty_cache()
gc.collect()
return images
# Return (images, audio) - audio is None for default runner
return images, None
@@ -432,7 +432,7 @@ class WanAudioRunner(WanRunner):
ret["target_shape"] = self.config.target_shape
return ret
def run(self):
def run(self, save_video=True):
def load_audio(in_path: str, sr: float = 16000):
audio_array, ori_sr = ta.load(in_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
@@ -548,7 +548,7 @@ class WanAudioRunner(WanRunner):
self.model.scheduler.reset()
if prev_latents is not None:
ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
_, nframe, height, width = self.model.scheduler.latents.shape
# bs = 1
frames_n = (nframe - 1) * 4 + 1
prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
@@ -592,15 +592,39 @@ class WanAudioRunner(WanRunner):
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
out_path = os.path.join("./", "video_merge.mp4")
audio_file = os.path.join("./", "audio_merge.wav")
comfyui_images = vae_to_comfyui_image(gen_lvideo)
save_to_video(comfyui_images, out_path, target_fps)
save_audio(merge_audio, audio_file, out_path, output_path=self.config.get("save_video_path", None))
os.remove(out_path)
os.remove(audio_file)
def run_pipeline(self):
# Apply frame interpolation if configured
if "video_frame_interpolation" in self.config:
assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
comfyui_images = self.vfi_model.interpolate_frames(
comfyui_images,
source_fps=target_fps,
target_fps=interpolation_target_fps,
)
# Update target_fps for saving
target_fps = interpolation_target_fps
# Convert audio to ComfyUI format
# Convert numpy array to torch tensor and add batch and channel dimensions
audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0) # [batch, channels, samples]
comfyui_audio = {"waveform": audio_waveform, "sample_rate": audio_sr}
# Save video if requested
if save_video and self.config.get("save_video_path", None):
out_path = os.path.join("./", "video_merge.mp4")
audio_file = os.path.join("./", "audio_merge.wav")
# Use the updated target_fps (after interpolation if applied)
save_to_video(comfyui_images, out_path, target_fps)
save_audio(merge_audio, audio_file, out_path, output_path=self.config.get("save_video_path", None))
os.remove(out_path)
os.remove(audio_file)
return comfyui_images, comfyui_audio
def run_pipeline(self, save_video=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
@@ -609,8 +633,10 @@ class WanAudioRunner(WanRunner):
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
self.run()
images, audio = self.run(save_video) # run() now returns both images and audio
self.end_run()
gc.collect()
torch.cuda.empty_cache()
return images, audio
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment