Commit a32f6801 authored by sandy, committed by GitHub
Browse files

[Feat] Make prev_frame_length parameter configurable via config (#276)



* [Feat] Make prev_frame_length parameter configurable via config

* Add newline at end of file in wan_audio_runner.py

---------
Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
parent 319f1f41
...@@ -237,7 +237,7 @@ class AudioProcessor: ...@@ -237,7 +237,7 @@ class AudioProcessor:
else: else:
interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1) interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length) res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
if res_frame_num > 5: if res_frame_num > prev_frame_length:
interval_num += 1 interval_num += 1
# Create segments # Create segments
...@@ -255,7 +255,7 @@ class AudioProcessor: ...@@ -255,7 +255,7 @@ class AudioProcessor:
segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length)) segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))
elif res_frame_num > 5 and idx == interval_num - 1: elif res_frame_num > prev_frame_length and idx == interval_num - 1:
# Last segment (might be shorter) # Last segment (might be shorter)
start_frame = idx * max_num_frames - idx * prev_frame_length start_frame = idx * max_num_frames - idx * prev_frame_length
audio_start, audio_end = self.get_audio_range(start_frame, expected_frames) audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
...@@ -284,6 +284,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -284,6 +284,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.frame_preprocessor = FramePreprocessor() self.frame_preprocessor = FramePreprocessor()
self.prev_frame_length = self.config.get("prev_frame_length", 5)
def init_scheduler(self): def init_scheduler(self):
"""Initialize consistency model scheduler""" """Initialize consistency model scheduler"""
...@@ -307,7 +308,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -307,7 +308,7 @@ class WanAudioRunner(WanRunner): # type:ignore
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len) expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
# Segment audio # Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81)) audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81), self.prev_frame_length)
return audio_segments, expected_frames return audio_segments, expected_frames
...@@ -464,7 +465,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -464,7 +465,7 @@ class WanAudioRunner(WanRunner): # type:ignore
audio_features = self.audio_adapter.forward_audio_proj(audio_features, self.model.scheduler.latents.shape[1]) audio_features = self.audio_adapter.forward_audio_proj(audio_features, self.model.scheduler.latents.shape[1])
self.inputs["audio_encoder_output"] = audio_features self.inputs["audio_encoder_output"] = audio_features
self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5) self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=self.prev_frame_length)
# Reset scheduler for non-first segments # Reset scheduler for non-first segments
if segment_idx > 0: if segment_idx > 0:
...@@ -475,8 +476,8 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -475,8 +476,8 @@ class WanAudioRunner(WanRunner): # type:ignore
self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float) self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
# Extract relevant frames # Extract relevant frames
start_frame = 0 if self.segment_idx == 0 else 5 start_frame = 0 if self.segment_idx == 0 else self.prev_frame_length
start_audio_frame = 0 if self.segment_idx == 0 else int(6 * self._audio_processor.audio_sr / self.config.get("target_fps", 16)) start_audio_frame = 0 if self.segment_idx == 0 else int((self.prev_frame_length + 1) * self._audio_processor.audio_sr / self.config.get("target_fps", 16))
if self.segment.is_last and self.segment.useful_length: if self.segment.is_last and self.segment.useful_length:
end_frame = self.segment.end_frame - self.segment.start_frame end_frame = self.segment.end_frame - self.segment.start_frame
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment