Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
e3663f4b
Commit
e3663f4b
authored
Sep 04, 2025
by
sandy
Committed by
GitHub
Sep 04, 2025
Browse files
[Ref] segment split and merge (#285)
parent
ad73b271
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
64 deletions
+36
-64
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+36
-64
No files found.
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
e3663f4b
...
class AudioSegment:
    """One slice of the driving audio, aligned to a window of video frames.

    NOTE(review): instances are constructed positionally with 3-5 arguments
    elsewhere in this file, consistent with a dataclass — confirm the
    decorator in the full source.
    """

    audio_array: np.ndarray  # audio samples belonging to this segment
    start_frame: int  # first video frame covered by this segment
    end_frame: int  # end of the covered frame span (presumably exclusive — confirm against get_audio_range)
    is_last: bool = False  # True only for the final segment of the clip
    useful_length: Optional[int] = None  # valid sample count before zero-padding, when the tail was padded
class
FramePreprocessorTorchVersion
:
class
FramePreprocessorTorchVersion
:
...
@@ -228,6 +226,7 @@ class AudioProcessor:
...
@@ -228,6 +226,7 @@ class AudioProcessor:
def
__init__
(
self
,
audio_sr
:
int
=
16000
,
target_fps
:
int
=
16
):
def
__init__
(
self
,
audio_sr
:
int
=
16000
,
target_fps
:
int
=
16
):
self
.
audio_sr
=
audio_sr
self
.
audio_sr
=
audio_sr
self
.
target_fps
=
target_fps
self
.
target_fps
=
target_fps
self
.
audio_frame_rate
=
audio_sr
//
target_fps
def
load_audio
(
self
,
audio_path
:
str
)
->
np
.
ndarray
:
def
load_audio
(
self
,
audio_path
:
str
)
->
np
.
ndarray
:
"""Load and resample audio"""
"""Load and resample audio"""
...
@@ -237,63 +236,48 @@ class AudioProcessor:
...
@@ -237,63 +236,48 @@ class AudioProcessor:
def
get_audio_range
(
self
,
start_frame
:
int
,
end_frame
:
int
)
->
Tuple
[
int
,
int
]:
def
get_audio_range
(
self
,
start_frame
:
int
,
end_frame
:
int
)
->
Tuple
[
int
,
int
]:
"""Calculate audio range for given frame range"""
"""Calculate audio range for given frame range"""
audio_frame_rate
=
self
.
audio_sr
/
self
.
target_fps
return
round
(
start_frame
*
self
.
audio_frame_rate
),
round
(
end_frame
*
self
.
audio_frame_rate
)
return
round
(
start_frame
*
audio_frame_rate
),
round
(
end_frame
*
audio_frame_rate
)
def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
    """Segment audio based on frame requirements"""
    segments = []
    # Frame-window boundaries for each clip; windows overlap by prev_frame_length frames.
    segments_idx = self.init_segments_idx(expected_frames, max_num_frames, prev_frame_length)
    # Trim the source audio to exactly the span covered by expected_frames.
    audio_start, audio_end = self.get_audio_range(0, expected_frames)
    audio_array_ori = audio_array[audio_start:audio_end]
    for idx, (start_idx, end_idx) in enumerate(segments_idx):
        # Slice the samples for this frame window out of the trimmed audio.
        audio_start, audio_end = self.get_audio_range(start_idx, end_idx)
        audio_array = audio_array_ori[audio_start:audio_end]
        if idx < len(segments_idx) - 1:
            # Non-final windows: report the segment as ending where the next
            # window starts, so the stored frame spans do not overlap.
            end_idx = segments_idx[idx + 1][0]
        else:
            if audio_array.shape[0] < audio_end - audio_start:
                # The final window may extend past the real audio: zero-pad the
                # samples up to the window size and shrink end_idx by the number
                # of whole frames that are padding.
                padding_len = audio_end - audio_start - audio_array.shape[0]
                audio_array = np.concatenate((audio_array, np.zeros(padding_len)), axis=0)
                end_idx = end_idx - padding_len // self.audio_frame_rate
        segments.append(AudioSegment(audio_array, start_idx, end_idx))
    # Drop large intermediate buffers eagerly.
    del audio_array, audio_array_ori
    return segments
def init_segments_idx(self, total_frame: int, clip_frame: int = 81, overlap_frame: int = 5) -> list[tuple[int, int]]:
    """Initialize segment indices with overlap.

    Splits ``total_frame`` frames into windows of ``clip_frame`` frames whose
    starts advance by ``clip_frame - overlap_frame``, so consecutive windows
    overlap by ``overlap_frame`` frames.

    Args:
        total_frame: total number of video frames to cover.
        clip_frame: window length in frames.
        overlap_frame: frames shared between consecutive windows.

    Returns:
        List of ``(start, end)`` frame-index pairs. The final window may
        extend past ``total_frame`` (callers pad the audio for the overhang).
        Fix: the previous annotation claimed 3-tuples, but only pairs are
        ever appended.
    """
    start_end_list = []
    min_frame = clip_frame
    for start in range(0, total_frame, clip_frame - overlap_frame):
        is_last = start + clip_frame >= total_frame
        end = min(start + clip_frame, total_frame)
        if end - start < min_frame:
            # Short tail: extend to a full window (may run past total_frame).
            end = start + min_frame
        if ((end - start) - 1) % 4 != 0:
            # Keep window length congruent to 1 (mod 4) — presumably a temporal
            # VAE stride constraint; TODO confirm against the model config.
            end = start + (((end - start) - 1) // 4) * 4 + 1
        start_end_list.append((start, end))
        if is_last:
            break
    return start_end_list
@
RUNNER_REGISTER
(
"seko_talk"
)
@
RUNNER_REGISTER
(
"seko_talk"
)
class
WanAudioRunner
(
WanRunner
):
# type:ignore
class
WanAudioRunner
(
WanRunner
):
# type:ignore
...
@@ -480,7 +464,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -480,7 +464,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def init_run_segment(self, segment_idx, audio_array=None):
    """Select the audio segment that drives generation step ``segment_idx``."""
    self.segment_idx = segment_idx
    if audio_array is not None:
        # Wrap a caller-supplied waveform as one segment spanning the whole clip.
        self.segment = AudioSegment(audio_array, 0, audio_array.shape[0])
    else:
        # Use the segment precomputed by segment_audio().
        self.segment = self.inputs["audio_segments"][segment_idx]
...
@@ -504,21 +488,9 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -504,21 +488,9 @@ class WanAudioRunner(WanRunner): # type:ignore
@
ProfilingContext4Debug
(
"End run segment"
)
@
ProfilingContext4Debug
(
"End run segment"
)
def
end_run_segment
(
self
):
def
end_run_segment
(
self
):
self
.
gen_video
=
torch
.
clamp
(
self
.
gen_video
,
-
1
,
1
).
to
(
torch
.
float
)
self
.
gen_video
=
torch
.
clamp
(
self
.
gen_video
,
-
1
,
1
).
to
(
torch
.
float
)
useful_length
=
self
.
segment
.
end_frame
-
self
.
segment
.
start_frame
# Extract relevant frames
self
.
gen_video_list
.
append
(
self
.
gen_video
[:,
:,
:
useful_length
].
cpu
())
start_frame
=
0
if
self
.
segment_idx
==
0
else
self
.
prev_frame_length
self
.
cut_audio_list
.
append
(
self
.
segment
.
audio_array
[:
useful_length
*
self
.
_audio_processor
.
audio_frame_rate
])
start_audio_frame
=
0
if
self
.
segment_idx
==
0
else
int
(
self
.
prev_frame_length
*
self
.
_audio_processor
.
audio_sr
/
self
.
config
.
get
(
"target_fps"
,
16
))
if
self
.
segment
.
is_last
and
self
.
segment
.
useful_length
:
end_frame
=
self
.
segment
.
end_frame
-
self
.
segment
.
start_frame
self
.
gen_video_list
.
append
(
self
.
gen_video
[:,
:,
start_frame
:
end_frame
].
cpu
())
self
.
cut_audio_list
.
append
(
self
.
segment
.
audio_array
[
start_audio_frame
:
self
.
segment
.
useful_length
])
elif
self
.
segment
.
useful_length
and
self
.
inputs
[
"expected_frames"
]
<
self
.
config
.
get
(
"target_video_length"
,
81
):
self
.
gen_video_list
.
append
(
self
.
gen_video
[:,
:,
start_frame
:
self
.
inputs
[
"expected_frames"
]].
cpu
())
self
.
cut_audio_list
.
append
(
self
.
segment
.
audio_array
[
start_audio_frame
:
self
.
segment
.
useful_length
])
else
:
self
.
gen_video_list
.
append
(
self
.
gen_video
[:,
:,
start_frame
:].
cpu
())
self
.
cut_audio_list
.
append
(
self
.
segment
.
audio_array
[
start_audio_frame
:])
if
self
.
va_recorder
:
if
self
.
va_recorder
:
cur_video
=
vae_to_comfyui_image
(
self
.
gen_video_list
[
-
1
])
cur_video
=
vae_to_comfyui_image
(
self
.
gen_video_list
[
-
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment