"doc/vscode:/vscode.git/clone" did not exist on "0f73f40da0dfc0c022217d667a67f7044ae6a28a"
Commit a32f6801 authored by sandy, committed by GitHub
Browse files

[Feat] Make prev_frame_length parameter configurable via config (#276)



* [Feat] Make prev_frame_length parameter configurable via config

* Add newline at end of file in wan_audio_runner.py

---------
Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
parent 319f1f41
......@@ -237,7 +237,7 @@ class AudioProcessor:
else:
interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
if res_frame_num > 5:
if res_frame_num > prev_frame_length:
interval_num += 1
# Create segments
......@@ -255,7 +255,7 @@ class AudioProcessor:
segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))
elif res_frame_num > 5 and idx == interval_num - 1:
elif res_frame_num > prev_frame_length and idx == interval_num - 1:
# Last segment (might be shorter)
start_frame = idx * max_num_frames - idx * prev_frame_length
audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
......@@ -284,6 +284,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def __init__(self, config):
super().__init__(config)
self.frame_preprocessor = FramePreprocessor()
self.prev_frame_length = self.config.get("prev_frame_length", 5)
def init_scheduler(self):
"""Initialize consistency model scheduler"""
......@@ -307,7 +308,7 @@ class WanAudioRunner(WanRunner): # type:ignore
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
# Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81))
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81), self.prev_frame_length)
return audio_segments, expected_frames
......@@ -464,7 +465,7 @@ class WanAudioRunner(WanRunner): # type:ignore
audio_features = self.audio_adapter.forward_audio_proj(audio_features, self.model.scheduler.latents.shape[1])
self.inputs["audio_encoder_output"] = audio_features
self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)
self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=self.prev_frame_length)
# Reset scheduler for non-first segments
if segment_idx > 0:
......@@ -475,8 +476,8 @@ class WanAudioRunner(WanRunner): # type:ignore
self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
# Extract relevant frames
start_frame = 0 if self.segment_idx == 0 else 5
start_audio_frame = 0 if self.segment_idx == 0 else int(6 * self._audio_processor.audio_sr / self.config.get("target_fps", 16))
start_frame = 0 if self.segment_idx == 0 else self.prev_frame_length
start_audio_frame = 0 if self.segment_idx == 0 else int((self.prev_frame_length + 1) * self._audio_processor.audio_sr / self.config.get("target_fps", 16))
if self.segment.is_last and self.segment.useful_length:
end_frame = self.segment.end_frame - self.segment.start_frame
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment