Commit e3663f4b authored by sandy, committed by GitHub

[Ref] segment split and merge (#285)

parent ad73b271
@@ -175,8 +175,6 @@ class AudioSegment:
     audio_array: np.ndarray
     start_frame: int
     end_frame: int
-    is_last: bool = False
-    useful_length: Optional[int] = None


 class FramePreprocessorTorchVersion:
@@ -228,6 +226,7 @@ class AudioProcessor:
     def __init__(self, audio_sr: int = 16000, target_fps: int = 16):
         self.audio_sr = audio_sr
         self.target_fps = target_fps
+        self.audio_frame_rate = audio_sr // target_fps

     def load_audio(self, audio_path: str) -> np.ndarray:
         """Load and resample audio"""
@@ -237,63 +236,48 @@ class AudioProcessor:
     def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
         """Calculate audio range for given frame range"""
-        audio_frame_rate = self.audio_sr / self.target_fps
-        return round(start_frame * audio_frame_rate), round(end_frame * audio_frame_rate)
+        return round(start_frame * self.audio_frame_rate), round(end_frame * self.audio_frame_rate)

     def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
         """Segment audio based on frame requirements"""
         segments = []
-
-        # Calculate intervals
-        interval_num = 1
-        res_frame_num = 0
-        if expected_frames <= max_num_frames:
-            interval_num = 1
-        else:
-            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
-            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
-            if res_frame_num > prev_frame_length:
-                interval_num += 1
-
-        # Create segments
-        for idx in range(interval_num):
-            if idx == 0:
-                # First segment
-                audio_start, audio_end = self.get_audio_range(0, max_num_frames)
-                segment_audio = audio_array[audio_start:audio_end]
-                useful_length = None
-                if expected_frames < max_num_frames:
-                    useful_length = segment_audio.shape[0]
-                    max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
-                    segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
-                segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))
-            elif res_frame_num > prev_frame_length and idx == interval_num - 1:
-                # Last segment (might be shorter)
-                start_frame = idx * max_num_frames - idx * prev_frame_length
-                audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
-                segment_audio = audio_array[audio_start:audio_end]
-                useful_length = segment_audio.shape[0]
-                max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
-                segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
-                segments.append(AudioSegment(segment_audio, start_frame, expected_frames, True, useful_length))
-            else:
-                # Middle segments
-                start_frame = idx * max_num_frames - idx * prev_frame_length
-                end_frame = (idx + 1) * max_num_frames - idx * prev_frame_length
-                audio_start, audio_end = self.get_audio_range(start_frame, end_frame)
-                segment_audio = audio_array[audio_start:audio_end]
-                segments.append(AudioSegment(segment_audio, start_frame, end_frame, False))
+        segments_idx = self.init_segments_idx(expected_frames, max_num_frames, prev_frame_length)
+        audio_start, audio_end = self.get_audio_range(0, expected_frames)
+        audio_array_ori = audio_array[audio_start:audio_end]
+
+        for idx, (start_idx, end_idx) in enumerate(segments_idx):
+            audio_start, audio_end = self.get_audio_range(start_idx, end_idx)
+            audio_array = audio_array_ori[audio_start:audio_end]
+            if idx < len(segments_idx) - 1:
+                end_idx = segments_idx[idx + 1][0]
+            else:
+                if audio_array.shape[0] < audio_end - audio_start:
+                    padding_len = audio_end - audio_start - audio_array.shape[0]
+                    audio_array = np.concatenate((audio_array, np.zeros(padding_len)), axis=0)
+                    end_idx = end_idx - padding_len // self.audio_frame_rate
+            segments.append(AudioSegment(audio_array, start_idx, end_idx))
+        del audio_array, audio_array_ori

         return segments

+    def init_segments_idx(self, total_frame: int, clip_frame: int = 81, overlap_frame: int = 5) -> list[tuple[int, int]]:
+        """Initialize segment indices with overlap"""
+        start_end_list = []
+        min_frame = clip_frame
+        for start in range(0, total_frame, clip_frame - overlap_frame):
+            is_last = start + clip_frame >= total_frame
+            end = min(start + clip_frame, total_frame)
+            if end - start < min_frame:
+                end = start + min_frame
+            if ((end - start) - 1) % 4 != 0:
+                end = start + (((end - start) - 1) // 4) * 4 + 1
+            start_end_list.append((start, end))
+            if is_last:
+                break
+        return start_end_list
+

 @RUNNER_REGISTER("seko_talk")
 class WanAudioRunner(WanRunner): # type:ignore
@@ -480,7 +464,7 @@ class WanAudioRunner(WanRunner): # type:ignore
     def init_run_segment(self, segment_idx, audio_array=None):
         self.segment_idx = segment_idx
         if audio_array is not None:
-            self.segment = AudioSegment(audio_array, 0, audio_array.shape[0], False)
+            self.segment = AudioSegment(audio_array, 0, audio_array.shape[0])
         else:
             self.segment = self.inputs["audio_segments"][segment_idx]
@@ -504,21 +488,9 @@ class WanAudioRunner(WanRunner): # type:ignore
     @ProfilingContext4Debug("End run segment")
     def end_run_segment(self):
         self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
-
-        # Extract relevant frames
-        start_frame = 0 if self.segment_idx == 0 else self.prev_frame_length
-        start_audio_frame = 0 if self.segment_idx == 0 else int(self.prev_frame_length * self._audio_processor.audio_sr / self.config.get("target_fps", 16))
-        if self.segment.is_last and self.segment.useful_length:
-            end_frame = self.segment.end_frame - self.segment.start_frame
-            self.gen_video_list.append(self.gen_video[:, :, start_frame:end_frame].cpu())
-            self.cut_audio_list.append(self.segment.audio_array[start_audio_frame : self.segment.useful_length])
-        elif self.segment.useful_length and self.inputs["expected_frames"] < self.config.get("target_video_length", 81):
-            self.gen_video_list.append(self.gen_video[:, :, start_frame : self.inputs["expected_frames"]].cpu())
-            self.cut_audio_list.append(self.segment.audio_array[start_audio_frame : self.segment.useful_length])
-        else:
-            self.gen_video_list.append(self.gen_video[:, :, start_frame:].cpu())
-            self.cut_audio_list.append(self.segment.audio_array[start_audio_frame:])
+        useful_length = self.segment.end_frame - self.segment.start_frame
+        self.gen_video_list.append(self.gen_video[:, :, :useful_length].cpu())
+        self.cut_audio_list.append(self.segment.audio_array[: useful_length * self._audio_processor.audio_frame_rate])

         if self.va_recorder:
             cur_video = vae_to_comfyui_image(self.gen_video_list[-1])
...
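Review note: the refactor replaces the three-way first/middle/last branching with one windowing pass. Below is a minimal standalone sketch of the new windowing arithmetic, duplicated here for review rather than imported from the repo; the 200-frame case is hypothetical and assumes the defaults clip_frame=81, overlap_frame=5.

# Sketch of init_segments_idx: windows advance by clip_frame - overlap_frame,
# and every window length stays at 4k+1 frames (81 = 4*20 + 1).
def init_segments_idx(total_frame: int, clip_frame: int = 81, overlap_frame: int = 5) -> list[tuple[int, int]]:
    start_end_list = []
    for start in range(0, total_frame, clip_frame - overlap_frame):
        is_last = start + clip_frame >= total_frame
        end = min(start + clip_frame, total_frame)
        if end - start < clip_frame:
            # Short tail: keep a full-length window anyway; segment_audio
            # later zero-pads the missing audio and shrinks end_idx to match.
            end = start + clip_frame
        if ((end - start) - 1) % 4 != 0:
            end = start + (((end - start) - 1) // 4) * 4 + 1
        start_end_list.append((start, end))
        if is_last:
            break
    return start_end_list

# Hypothetical 200-frame clip: three 81-frame windows overlapping by 5 frames;
# the last window deliberately runs past the clip (233 > 200).
assert init_segments_idx(200) == [(0, 81), (76, 157), (152, 233)]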
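segment_audio and end_run_segment then share one invariant: each AudioSegment's end_frame is trimmed so that the kept prefixes tile the timeline exactly once. A hedged sketch of that trim rule follows, using the same hypothetical 200-frame case; min(end, total_frame) stands in for the padding_len arithmetic in the diff, which is equivalent whenever audio_sr is an exact multiple of target_fps.

# Non-final windows keep frames only up to the next window's start (the
# 5-frame overlap is regenerated by the next segment, not kept twice);
# the final window gives back the zero-padded frames.
def kept_ranges(windows: list[tuple[int, int]], total_frame: int) -> list[tuple[int, int]]:
    kept = []
    for i, (start, end) in enumerate(windows):
        if i < len(windows) - 1:
            end = windows[i + 1][0]
        else:
            end = min(end, total_frame)
        kept.append((start, end))
    return kept

ranges = kept_ranges([(0, 81), (76, 157), (152, 233)], total_frame=200)
assert ranges == [(0, 76), (76, 152), (152, 200)]
# end_run_segment keeps end_frame - start_frame video frames and
# (end_frame - start_frame) * audio_frame_rate audio samples per segment,
# so the concatenated output covers frames 0..200 exactly once.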