xuwx1 / LightX2V / Commits / d048b178

Commit d048b178, authored Jul 23, 2025 by gaclove

refactor: enhance WanAudioRunner to improve audio handling and frame interpolation

parent 5bd9bdbd

Showing 1 changed file with 458 additions and 459 deletions:

lightx2v/models/runners/wan/wan_audio_runner.py  (+458, -459)
@@ -4,6 +4,10 @@ import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from contextlib import contextmanager
from typing import Optional, Tuple, Union, List, Dict, Any
from dataclasses import dataclass
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
@@ -34,46 +38,266 @@ from torchvision.transforms.functional import resize
import subprocess
import warnings
from typing import Optional, Tuple, Union


def add_mask_to_frames(
    frames: np.ndarray,
    mask_rate: float = 0.1,
    rnd_state: np.random.RandomState = None,
) -> np.ndarray:
    if mask_rate is None:
        return frames
    if rnd_state is None:
        rnd_state = np.random.RandomState()
    h, w = frames.shape[-2:]
    mask = rnd_state.rand(h, w) > mask_rate
    frames = frames * mask
    return frames


@contextmanager
def memory_efficient_inference():
    """Context manager for memory-efficient inference"""
    try:
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
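Note (illustration, not part of the change): a minimal usage sketch of the context manager above, with a hypothetical generate() call standing in for any inference code:

    with memory_efficient_inference():
        video = generate()
    # On exit, the CUDA cache is emptied (when available) and gc.collect() runs.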
@dataclass
class AudioSegment:
    """Data class for audio segment information"""

    audio_array: np.ndarray
    start_frame: int
    end_frame: int
    is_last: bool = False
    useful_length: Optional[int] = None


class FramePreprocessor:
    """Handles frame preprocessing including noise and masking"""

    def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 0.1):
        self.noise_mean = noise_mean
        self.noise_std = noise_std
        self.mask_rate = mask_rate

    def add_noise(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
        """Add noise to frames"""
        if self.noise_mean is None or self.noise_std is None:
            return frames
        if rnd_state is None:
            rnd_state = np.random.RandomState()

        shape = frames.shape
        bs = 1 if len(shape) == 4 else shape[0]
        sigma = rnd_state.normal(loc=self.noise_mean, scale=self.noise_std, size=(bs,))
        sigma = np.exp(sigma)
        sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
        noise = rnd_state.randn(*shape) * sigma
        return frames + noise

    def add_mask(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
        """Add mask to frames"""
        if self.mask_rate is None:
            return frames
        if rnd_state is None:
            rnd_state = np.random.RandomState()

        h, w = frames.shape[-2:]
        mask = rnd_state.rand(h, w) > self.mask_rate
        return frames * mask

    def process_prev_frames(self, frames: torch.Tensor) -> torch.Tensor:
        """Process previous frames with noise and masking"""
        frames_np = frames.cpu().detach().numpy()
        frames_np = self.add_noise(frames_np)
        frames_np = self.add_mask(frames_np)
        return torch.from_numpy(frames_np).to(dtype=frames.dtype, device=frames.device)
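Note (illustration, not part of the change): with the defaults above, the per-sample noise scale is log-normal, sigma = exp(N(-3.0, 0.5)), so a typical sigma is around exp(-3.0) ≈ 0.05 and most draws fall roughly in [0.03, 0.08]; the mask keeps each pixel with probability 1 - mask_rate = 0.9. A minimal sketch of the same effect on a dummy clip:

    import numpy as np

    pre = FramePreprocessor(noise_mean=-3.0, noise_std=0.5, mask_rate=0.1)
    frames = np.zeros((1, 3, 5, 64, 64), dtype=np.float32)   # (bs, C, T, H, W)
    out = pre.add_mask(pre.add_noise(frames))                # lightly noised, ~10% of pixels zeroed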
class AudioProcessor:
    """Handles audio loading and segmentation"""

    def __init__(self, audio_sr: int = 16000, target_fps: int = 16):
        self.audio_sr = audio_sr
        self.target_fps = target_fps

    def load_audio(self, audio_path: str) -> np.ndarray:
        """Load and resample audio"""
        audio_array, ori_sr = ta.load(audio_path)
        audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=self.audio_sr)
        return audio_array.numpy()

    def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
        """Calculate audio range for given frame range"""
        audio_frame_rate = self.audio_sr / self.target_fps
        return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
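Note (illustration, not part of the change): at the defaults of 16000 Hz audio and 16 fps video, one video frame corresponds to 1000 audio samples, and the +1 on the end index grabs one extra frame's worth of audio. For example:

    proc = AudioProcessor(audio_sr=16000, target_fps=16)
    proc.get_audio_range(0, 81)    # -> (0, 82000): samples for an 81-frame segment, about 5.125 s
    proc.get_audio_range(76, 157)  # -> (76000, 158000)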
    def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
        """Segment audio based on frame requirements"""
        segments = []

        # Calculate intervals
        interval_num = 1
        res_frame_num = 0

        if expected_frames <= max_num_frames:
            interval_num = 1
        else:
            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
            if res_frame_num > 5:
                interval_num += 1

        # Create segments
        for idx in range(interval_num):
            if idx == 0:
                # First segment
                audio_start, audio_end = self.get_audio_range(0, max_num_frames)
                segment_audio = audio_array[audio_start:audio_end]
                useful_length = None

                if expected_frames < max_num_frames:
                    useful_length = segment_audio.shape[0]
                    max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
                    segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)

                segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))

            elif res_frame_num > 5 and idx == interval_num - 1:
                # Last segment (might be shorter)
                start_frame = idx * max_num_frames - idx * prev_frame_length
                audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
                segment_audio = audio_array[audio_start:audio_end]
                useful_length = segment_audio.shape[0]
                max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
                segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
                segments.append(AudioSegment(segment_audio, start_frame, expected_frames, True, useful_length))

            else:
                # Middle segments
                start_frame = idx * max_num_frames - idx * prev_frame_length
                end_frame = (idx + 1) * max_num_frames - idx * prev_frame_length
                audio_start, audio_end = self.get_audio_range(start_frame, end_frame)
                segment_audio = audio_array[audio_start:audio_end]
                segments.append(AudioSegment(segment_audio, start_frame, end_frame, False))

        return segments
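Note (illustration, not part of the change): a worked example of the interval arithmetic above. For a 10-second clip at 16 fps, expected_frames = 160 with max_num_frames = 81 and prev_frame_length = 5:

    interval_num = max(int((160 - 81) / (81 - 5)) + 1, 1)   # = 2
    res_frame_num = 160 - 2 * (81 - 5)                       # = 8, which is > 5, so interval_num becomes 3

    # Resulting segments (start_frame, end_frame, is_last):
    #   (0,   81,  False)  first segment
    #   (76,  157, False)  middle segment, re-uses the last 5 frames of the previous one
    #   (152, 160, True)   short last segment, its audio zero-padded up to the fixed segment length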
class VideoGenerator:
    """Handles video generation for each segment"""

    def __init__(self, model, vae_encoder, vae_decoder, config):
        self.model = model
        self.vae_encoder = vae_encoder
        self.vae_decoder = vae_decoder
        self.config = config
        self.frame_preprocessor = FramePreprocessor()

    def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
        """Prepare previous latents for conditioning"""
        if prev_video is None:
            return None

        device = self.model.device
        dtype = torch.bfloat16
        vae_dtype = torch.float
        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w

        prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)

        # Extract and process last frames
        last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
        last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
        prev_frames[:, :, :prev_frame_length] = last_frames

        prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)

        # Create mask
        prev_token_length = (prev_frame_length - 1) // 4 + 1
        _, nframe, height, width = self.model.scheduler.latents.shape
        frames_n = (nframe - 1) * 4 + 1
        prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)

        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
        prev_mask[:, prev_frame_len:] = 0
        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)

        return {"prev_latents": prev_latents, "prev_mask": prev_mask}

    def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
        """Rearrange mask for WAN model"""
        if mask.ndim == 3:
            mask = mask[None]
        assert mask.ndim == 4
        _, t, h, w = mask.shape
        assert t == ((t - 1) // 4 * 4 + 1)
        mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
        mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
        mask = mask.view(mask.shape[1] // 4, 4, h, w)
        return mask.transpose(0, 1)
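Note (illustration, not part of the change): the rearrangement above assumes T = 4k + 1 input frames; repeating the first frame four times yields 4(k + 1) frames, which are then folded into groups of four. A shape trace for an 81-frame mask:

    # input                      : (1, 81, h, w)   with 81 = 4 * 20 + 1
    # repeat frame 0 four times  : (1, 84, h, w)
    # view(84 // 4, 4, h, w)     : (21, 4, h, w)
    # transpose(0, 1)            : (4, 21, h, w)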
    @torch.no_grad()
    def generate_segment(
        self,
        inputs: Dict[str, Any],
        audio_features: torch.Tensor,
        prev_video: Optional[torch.Tensor] = None,
        prev_frame_length: int = 5,
        segment_idx: int = 0,
    ) -> torch.Tensor:
        """Generate video segment"""
        # Update inputs with audio features
        inputs["audio_encoder_output"] = audio_features

        # Reset scheduler for non-first segments
        if segment_idx > 0:
            self.model.scheduler.reset()

        # Prepare previous latents - ALWAYS needed, even for first segment
        device = self.model.device
        dtype = torch.bfloat16
        vae_dtype = torch.float
        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
        max_num_frames = self.config.target_video_length

        if segment_idx == 0:
            # First segment - create zero frames
            prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
            prev_len = 0
        else:
            # Subsequent segments - use previous video
            previmg_encoder_output = self.prepare_prev_latents(prev_video, prev_frame_length)
            if previmg_encoder_output:
                prev_latents = previmg_encoder_output["prev_latents"]
                prev_len = (prev_frame_length - 1) // 4 + 1
            else:
                # Fallback to zeros if prepare_prev_latents fails
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = 0

        # Create mask for prev_latents
        _, nframe, height, width = self.model.scheduler.latents.shape
        frames_n = (nframe - 1) * 4 + 1
        prev_frame_len = max((prev_len - 1) * 4 + 1, 0)

        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
        prev_mask[:, prev_frame_len:] = 0
        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)

        # Always set previmg_encoder_output
        inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}

        # Run inference loop
        for step_index in range(self.model.scheduler.infer_steps):
            logger.info(f"==> Segment {segment_idx}, Step {step_index} / {self.model.scheduler.infer_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

        # Decode latents
        latents = self.model.scheduler.latents
        generator = self.model.scheduler.generator
        gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)

        return gen_video


def add_noise_to_frames(
    frames: np.ndarray,
    noise_mean: float = -3.0,
    noise_std: float = 0.5,
    rnd_state: np.random.RandomState = None,
) -> np.ndarray:
    if noise_mean is None or noise_std is None:
        return frames
    if rnd_state is None:
        rnd_state = np.random.RandomState()

    shape = frames.shape
    bs = 1 if len(shape) == 4 else shape[0]
    sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
    sigma = np.exp(sigma)
    sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
    noise = rnd_state.randn(*shape) * sigma
    frames = frames + noise
    return frames


def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
@@ -131,221 +355,226 @@ def adaptive_resize(img):
    return cropped_img, target_h, target_w


def array_to_video(
    image_array: np.ndarray,
    output_path: str,
    fps: int | float = 30,
    resolution: tuple[int, int] | tuple[float, float] | None = None,
    disable_log: bool = False,
    lossless: bool = True,
    output_pix_fmt: str = "yuv420p",
) -> None:
    """Convert an array to a video directly, gif not supported.

    Args:
        image_array (np.ndarray): shape should be (f * h * w * 3).
        output_path (str): output video file path.
        fps (Union[int, float], optional): fps. Defaults to 30.
        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
            optional): (height, width) of the output video.
            Defaults to None.
        disable_log (bool, optional): whether to silence the ffmpeg command info.
            Defaults to False.
        output_pix_fmt (str): output pix_fmt in ffmpeg command.

    Raises:
        FileNotFoundError: check output path.
        TypeError: check input array.

    Returns:
        None.
    """
    if not isinstance(image_array, np.ndarray):
        raise TypeError("Input should be np.ndarray.")
    assert image_array.ndim == 4
    assert image_array.shape[-1] == 3
    if resolution:
        height, width = resolution
        width += width % 2
        height += height % 2
    else:
        image_array = pad_for_libx264(image_array)
        height, width = image_array.shape[1], image_array.shape[2]
    if lossless:
        command = [
            "/usr/bin/ffmpeg",
            "-y",  # (optional) overwrite output file if it exists
            "-f",
            "rawvideo",
            "-s",
            f"{int(width)}x{int(height)}",  # size of one frame
            "-pix_fmt",
            "bgr24",
            "-r",
            f"{fps}",  # frames per second
            "-loglevel",
            "error",
            "-threads",
            "4",
            "-i",
            "-",  # The input comes from a pipe
            "-vcodec",
            "libx264rgb",
            "-crf",
            "0",
            "-an",  # Tells FFMPEG not to expect any audio
            output_path,
        ]
    else:
        output_pix_fmt = output_pix_fmt or "yuv420p"
        command = [
            "/usr/bin/ffmpeg",
            "-y",  # (optional) overwrite output file if it exists
            "-f",
            "rawvideo",
            "-s",
            f"{int(width)}x{int(height)}",  # size of one frame
            "-pix_fmt",
            "bgr24",
            "-r",
            f"{fps}",  # frames per second
            "-loglevel",
            "error",
            "-threads",
            "4",
            "-i",
            "-",  # The input comes from a pipe
            "-vcodec",
            "libx264",
            "-pix_fmt",
            f"{output_pix_fmt}",
            "-an",  # Tells FFMPEG not to expect any audio
            output_path,
        ]
    if output_pix_fmt is not None:
        command += ["-pix_fmt", output_pix_fmt]
    if not disable_log:
        print(f'Running "{" ".join(command)}"')
    process = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.stdin is None or process.stderr is None:
        raise BrokenPipeError("No buffer received.")
    index = 0
    while True:
        if index >= image_array.shape[0]:
            break
        process.stdin.write(image_array[index].tobytes())
        index += 1
    process.stdin.close()
    process.stderr.close()
    process.wait()
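Note (illustration, not part of the change): a minimal call of array_to_video on dummy data. It assumes an ffmpeg binary at /usr/bin/ffmpeg, as hard-coded above, and frames laid out as (f, h, w, 3) uint8 in BGR order:

    import numpy as np

    frames = (np.random.rand(16, 64, 64, 3) * 255).astype(np.uint8)
    array_to_video(frames, "demo_clip.mp4", fps=16, lossless=False)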
def pad_for_libx264(image_array):
    if image_array.ndim == 2 or (image_array.ndim == 3 and image_array.shape[2] == 3):
        hei_index = 0
        wid_index = 1
    elif image_array.ndim == 4 or (image_array.ndim == 3 and image_array.shape[2] != 3):
        hei_index = 1
        wid_index = 2
    else:
        return image_array

    hei_pad = image_array.shape[hei_index] % 2
    wid_pad = image_array.shape[wid_index] % 2
    if hei_pad + wid_pad > 0:
        pad_width = []
        for dim_index in range(image_array.ndim):
            if dim_index == hei_index:
                pad_width.append((0, hei_pad))
            elif dim_index == wid_index:
                pad_width.append((0, wid_pad))
            else:
                pad_width.append((0, 0))
        values = 0
        image_array = np.pad(image_array, pad_width, mode="constant", constant_values=values)
    return image_array


def generate_unique_path(path):
    if not os.path.exists(path):
        return path
    root, ext = os.path.splitext(path)
    index = 1
    new_path = f"{root}-{index}{ext}"
    while os.path.exists(new_path):
        index += 1
        new_path = f"{root}-{index}{ext}"
    return new_path
def save_audio(
    audio_array,
    audio_name: str,
    video_name: str,
    sr: int = 16000,
    output_path: Optional[str] = None,
):
    logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
    ta.save(
        audio_name,
        torch.tensor(audio_array[None]),
        sample_rate=sr,
    )
    if output_path is None:
        out_video = f"{video_name[:-4]}_with_audio.mp4"
    else:
        out_video = output_path
    parent_dir = os.path.dirname(out_video)
    if parent_dir and not os.path.exists(parent_dir):
        os.makedirs(parent_dir, exist_ok=True)
    if os.path.exists(out_video):
        os.remove(out_video)
    subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_name, "-i", audio_name, out_video])
    return out_video


@RUNNER_REGISTER("wan2.1_audio")
class WanAudioRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)
        self._is_initialized = False
        self._audio_adapter_pipe = None
        self._audio_processor = None
        self._video_generator = None
        self._audio_preprocess = None

    def initialize_once(self):
        """Initialize all models once for multiple runs"""
        if self._is_initialized:
            return

        logger.info("Initializing models (one-time setup)...")

        # Initialize audio processor
        audio_sr = self.config.get("audio_sr", 16000)
        target_fps = self.config.get("target_fps", 16)
        self._audio_processor = AudioProcessor(audio_sr, target_fps)

        # Load audio feature extractor
        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")

        # Initialize scheduler
        self.init_scheduler()

        self._is_initialized = True
        logger.info("Model initialization complete")

    def init_scheduler(self):
        """Initialize consistency model scheduler"""
        scheduler = ConsistencyModelScheduler(self.config)
        self.model.set_scheduler(scheduler)

    def load_audio_models(self):
        # Audio feature extractor
        self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")

    def load_audio_adapter_lazy(self):
        """Lazy load audio adapter when needed"""
        if self._audio_adapter_pipe is not None:
            return self._audio_adapter_pipe

        # Audio adapter (audio-driven video generation)
        audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
        audio_adapter = AudioAdapter.from_transformer(
            self.model,
            audio_feature_dim=1024,
            interval=1,
            time_freq_dim=256,
            projection_transformer_layers=4,
        )
        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)

        # Audio feature encoder
        device = self.model.device
        audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
        self._audio_adapter_pipe = AudioAdapterPipe(
            audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0
        )
        return self._audio_adapter_pipe
    def prepare_inputs(self):
        """Prepare inputs for the model"""
        image_encoder_output = None
        if os.path.isfile(self.config.image_path):
            with ProfilingContext("Run Img Encoder"):
                vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
                image_encoder_output = {
                    "clip_encoder_out": clip_encoder_out,
                    "vae_encode_out": vae_encode_out,
                }

        with ProfilingContext("Run Text Encoder"):
            img = Image.open(self.config["image_path"]).convert("RGB")
            text_encoder_output = self.run_text_encoder(self.config["prompt"], img)

        self.set_target_shape()

        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output, "audio_adapter_pipe": self.load_audio_adapter_lazy()}
    def run_pipeline(self, save_video=True):
        """Optimized pipeline with modular components"""
        # Ensure models are initialized
        self.initialize_once()

        # Initialize video generator if needed
        if self._video_generator is None:
            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config)

        # Prepare inputs
        with memory_efficient_inference():
            if self.config["use_prompt_enhancer"]:
                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
            self.inputs = self.prepare_inputs()
            self.model.scheduler.prepare(self.inputs["image_encoder_output"])

        # Process audio
        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
        video_duration = self.config.get("video_duration", 5)
        target_fps = self.config.get("target_fps", 16)
        max_num_frames = self.config.get("target_video_length", 81)

        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)

        # Segment audio
        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)

        # Generate video segments
        gen_video_list = []
        cut_audio_list = []
        prev_video = None

        for idx, segment in enumerate(audio_segments):
            # Update seed for each segment
            self.config.seed = self.config.seed + idx
            torch.manual_seed(self.config.seed)

            logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")

            # Process audio features
            audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)

            # Generate video segment
            with memory_efficient_inference():
                gen_video = self._video_generator.generate_segment(
                    self.inputs.copy(),  # Copy to avoid modifying original
                    audio_features,
                    prev_video=prev_video,
                    prev_frame_length=5,
                    segment_idx=idx,
                )

            # Extract relevant frames
            start_frame = 0 if idx == 0 else 5
            start_audio_frame = 0 if idx == 0 else int(6 * self._audio_processor.audio_sr / target_fps)

            if segment.is_last and segment.useful_length:
                end_frame = segment.end_frame - segment.start_frame
                gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
                cut_audio_list.append(segment.audio_array[start_audio_frame:segment.useful_length])
            elif segment.useful_length and expected_frames < max_num_frames:
                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                cut_audio_list.append(segment.audio_array[start_audio_frame:segment.useful_length])
            else:
                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                cut_audio_list.append(segment.audio_array[start_audio_frame:])

            # Update prev_video for next iteration
            prev_video = gen_video

            # Clean up GPU memory after each segment
            del gen_video
            torch.cuda.empty_cache()

        # Merge results
        with memory_efficient_inference():
            gen_lvideo = torch.cat(gen_video_list, dim=2).float()
            merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
            comfyui_images = vae_to_comfyui_image(gen_lvideo)

        # Apply frame interpolation if configured
        if "video_frame_interpolation" in self.config and self.vfi_model is not None:
            interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
            comfyui_images = self.vfi_model.interpolate_frames(
                comfyui_images,
                source_fps=target_fps,
                target_fps=interpolation_target_fps,
            )
            target_fps = interpolation_target_fps

        # Convert audio to ComfyUI format
        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
        comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}

        # Save video if requested
        if save_video and self.config.get("save_video_path", None):
            self._save_video_with_audio(comfyui_images, merge_audio, target_fps)

        # Final cleanup
        self.end_run()

        return comfyui_images, comfyui_audio
    def _save_video_with_audio(self, images, audio_array, fps):
        """Save video with audio"""
        import tempfile

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_tmp:
            video_path = video_tmp.name
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_tmp:
            audio_path = audio_tmp.name

        try:
            # Save video
            save_to_video(images, video_path, fps)

            # Save audio
            ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr)

            # Merge video and audio
            output_path = self.config.get("save_video_path")
            parent_dir = os.path.dirname(output_path)
            if parent_dir and not os.path.exists(parent_dir):
                os.makedirs(parent_dir, exist_ok=True)

            subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_path, "-i", audio_path, output_path])
            logger.info(f"Saved video with audio to: {output_path}")
        finally:
            # Clean up temp files
            if os.path.exists(video_path):
                os.remove(video_path)
            if os.path.exists(audio_path):
                os.remove(audio_path)
    def load_transformer(self):
        """Load transformer with LoRA support"""
        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
        if self.config.get("lora_configs") and self.config.lora_configs:
@@ -361,19 +590,21 @@ class WanAudioRunner(WanRunner):
        return base_model

    def load_image_encoder(self):
        """Load image encoder"""
        clip_model_dir = self.config["model_path"] + "/image_encoder"
        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
        return image_encoder

    def run_image_encoder(self, config, vae_model):
        """Run image encoder"""
        ref_img = Image.open(config.image_path)
        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
        ref_img = torch.from_numpy(ref_img).to(vae_model.device)
        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
        ref_img = ref_img[:, :3]

        # Resize and crop image
        cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
        config.tgt_h = tgt_h
        config.tgt_w = tgt_w
@@ -384,36 +615,13 @@ class WanAudioRunner(WanRunner):
        config.lat_h = lat_h
        config.lat_w = lat_w
        vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
        if isinstance(vae_encode_out, list):
            # Convert the list to a tensor
            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
        return vae_encode_out, clip_encoder_out
    def run_input_encoder_internal(self):
        image_encoder_output = None
        if os.path.isfile(self.config.image_path):
            with ProfilingContext("Run Img Encoder"):
                vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
                image_encoder_output = {
                    "clip_encoder_out": clip_encoder_out,
                    "vae_encode_out": vae_encode_out,
                }
                logger.info(f"clip_encoder_out: {clip_encoder_out.shape} vae_encode_out: {vae_encode_out.shape}")
        with ProfilingContext("Run Text Encoder"):
            logger.info(f"Prompt: {self.config['prompt']}")
            img = Image.open(self.config["image_path"]).convert("RGB")
            text_encoder_output = self.run_text_encoder(self.config["prompt"], img)
        self.set_target_shape()
        self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
        # del self.image_encoder  # drop the reference-image CLIP model; it is only used once
        gc.collect()
        torch.cuda.empty_cache()

    def set_target_shape(self):
        """Set target shape for generation"""
        ret = {}
        num_channels_latents = 16
        if self.config.task == "i2v":
@@ -427,216 +635,7 @@ class WanAudioRunner(WanRunner):
            ret["lat_w"] = self.config.lat_w
        else:
            error_msg = "t2v task is not supported in WanAudioRunner"
            assert False, error_msg
        ret["target_shape"] = self.config.target_shape
        return ret
    def run(self, save_video=True):
        def load_audio(in_path: str, sr: float = 16000):
            audio_array, ori_sr = ta.load(in_path)
            audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
            return audio_array.numpy()

        def get_audio_range(start_frame: int, end_frame: int, fps: float, audio_sr: float = 16000):
            audio_frame_rate = audio_sr / fps
            return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)

        def wan_mask_rearrange(mask: torch.Tensor):
            # mask: 1, T, H, W, where 1 means the input mask is one-channel
            if mask.ndim == 3:
                mask = mask[None]
            assert mask.ndim == 4
            _, t, h, w = mask.shape
            assert t == ((t - 1) // 4 * 4 + 1)
            mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
            mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
            mask = mask.view(mask.shape[1] // 4, 4, h, w)
            return mask.transpose(0, 1)  # 4, T // 4, H, W
        self.inputs["audio_adapter_pipe"] = self.load_audio_models()

        # process audio
        audio_sr = self.config.get("audio_sr", 16000)
        max_num_frames = self.config.get("target_video_length", 81)  # wan2.1: at most 81 frames per segment (5 s at 16 fps)
        target_fps = self.config.get("target_fps", 16)  # frame rate used to keep audio and video in sync
        video_duration = self.config.get("video_duration", 5)  # desired output video duration
        audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
        audio_len = int(audio_array.shape[0] / audio_sr * target_fps)

        prev_frame_length = 5
        prev_token_length = (prev_frame_length - 1) // 4 + 1
        max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
        interval_num = 1

        # expected_frames
        expected_frames = min(max(1, int(float(video_duration) * target_fps)), audio_len)
        res_frame_num = 0
        if expected_frames <= max_num_frames:
            interval_num = 1
        else:
            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
            if res_frame_num > 5:
                interval_num += 1

        audio_start, audio_end = get_audio_range(0, expected_frames, fps=target_fps, audio_sr=audio_sr)
        audio_array_ori = audio_array[audio_start:audio_end]

        gen_video_list = []
        cut_audio_list = []

        # reference latents
        tgt_h = self.config.tgt_h
        tgt_w = self.config.tgt_w
        device = self.model.scheduler.latents.device
        dtype = torch.bfloat16
        vae_dtype = torch.float

        for idx in range(interval_num):
            self.config.seed = self.config.seed + idx
            torch.manual_seed(self.config.seed)
            logger.info(f"### manual_seed: {self.config.seed} ####")
            useful_length = -1
            if idx == 0:
                # First segment: condition is zero-padded
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = 0
                audio_start, audio_end = get_audio_range(0, max_num_frames, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                if expected_frames < max_num_frames:
                    useful_length = audio_array.shape[0]
                    audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            elif res_frame_num > 5 and idx == interval_num - 1:
                # Last segment: may have fewer than 81 frames
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
                last_frames = last_frames.cpu().detach().numpy()
                last_frames = add_noise_to_frames(last_frames)
                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
                prev_frames[:, :, :prev_frame_length] = last_frames
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = prev_token_length
                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                useful_length = audio_array.shape[0]
                audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            else:
                # Middle segments: full 81 frames, conditioned on prev_latents
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
                last_frames = last_frames.cpu().detach().numpy()
                last_frames = add_noise_to_frames(last_frames)  # mean: -3.0, std: 0.5
                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
                prev_frames[:, :, :prev_frame_length] = last_frames
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = prev_token_length
                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            self.inputs["audio_encoder_output"] = audio_input_feat.to(device)

            if idx != 0:
                self.model.scheduler.reset()

            if prev_latents is not None:
                _, nframe, height, width = self.model.scheduler.latents.shape  # bs = 1
                frames_n = (nframe - 1) * 4 + 1
                prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
                prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
                prev_mask[:, prev_frame_len:] = 0
                prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
                previmg_encoder_output = {
                    "prev_latents": prev_latents,
                    "prev_mask": prev_mask,
                }
                self.inputs["previmg_encoder_output"] = previmg_encoder_output

            for step_index in range(self.model.scheduler.infer_steps):
                logger.info(f"==> step_index: {step_index} / {self.model.scheduler.infer_steps}")

                with ProfilingContext4Debug("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)

                with ProfilingContext4Debug("infer"):
                    self.model.infer(self.inputs)

                with ProfilingContext4Debug("step_post"):
                    self.model.scheduler.step_post()

            latents = self.model.scheduler.latents
            generator = self.model.scheduler.generator
            gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
            gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
            start_frame = 0 if idx == 0 else prev_frame_length
            start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)

            if res_frame_num > 5 and idx == interval_num - 1:
                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
            elif expected_frames < max_num_frames and useful_length != -1:
                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
            else:
                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:])

        gen_lvideo = torch.cat(gen_video_list, dim=2).float()
        merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
        comfyui_images = vae_to_comfyui_image(gen_lvideo)
        # Apply frame interpolation if configured
        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
            comfyui_images = self.vfi_model.interpolate_frames(
                comfyui_images,
                source_fps=target_fps,
                target_fps=interpolation_target_fps,
            )
            # Update target_fps for saving
            target_fps = interpolation_target_fps

        # Convert audio to ComfyUI format
        # Convert numpy array to torch tensor and add batch dimension
        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)  # [batch, channels, samples]
        comfyui_audio = {"waveform": audio_waveform, "sample_rate": audio_sr}

        # Save video if requested
        if save_video and self.config.get("save_video_path", None):
            out_path = os.path.join("./", "video_merge.mp4")
            audio_file = os.path.join("./", "audio_merge.wav")
            # Use the updated target_fps (after interpolation if applied)
            save_to_video(comfyui_images, out_path, target_fps)
            save_audio(merge_audio, audio_file, out_path, output_path=self.config.get("save_video_path", None))
            os.remove(out_path)
            os.remove(audio_file)

        return comfyui_images, comfyui_audio
    def run_pipeline(self, save_video=True):
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
        self.run_input_encoder_internal()
        self.set_target_shape()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        images, audio = self.run(save_video)  # run() now returns both images and audio
        self.end_run()
        gc.collect()
        torch.cuda.empty_cache()
        return images, audio
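For orientation, a minimal driving sketch of the refactored runner. The config construction and keys shown here are assumptions inferred from the code above (the runner reads keys such as model_path, image_path, prompt, audio_path, target_fps, target_video_length, video_duration and save_video_path, and accesses the config both as a mapping and via attributes), so the exact invocation in LightX2V may differ:

    # Hypothetical driver; not part of this commit.
    config = load_config("wan2.1_audio.json")   # assumed helper producing the attribute/mapping-style config
    runner = WanAudioRunner(config)
    images, audio = runner.run_pipeline(save_video=True)
    # images: ComfyUI-style frames; audio: {"waveform": tensor, "sample_rate": int}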