xuwx1 / LightX2V
Commit d048b178, authored Jul 23, 2025 by gaclove
refactor: enhance WanAudioRunner to improve audio handling and frame interpolation
parent 5bd9bdbd

Showing 1 changed file with 458 additions and 459 deletions:
lightx2v/models/runners/wan/wan_audio_runner.py (+458, -459)
@@ -4,6 +4,10 @@ import numpy as np
 import torch
 import torchvision.transforms.functional as TF
 from PIL import Image
+from contextlib import contextmanager
+from typing import Optional, Tuple, Union, List, Dict, Any
+from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.default_runner import DefaultRunner
@@ -34,46 +38,266 @@ from torchvision.transforms.functional import resize
import subprocess
import warnings

Removed in this hunk: the old `from typing import Optional, Tuple, Union` import (now consolidated at the top of the file) and the module-level helpers `add_mask_to_frames()` and `add_noise_to_frames()`, whose noise/masking logic moves into the new `FramePreprocessor` class below.

Added:

@contextmanager
def memory_efficient_inference():
    """Context manager for memory-efficient inference"""
    try:
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
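A minimal usage sketch (illustration only, not part of the diff; `run_one_segment` is a hypothetical stand-in): the finally-block cleanup runs whether or not the wrapped call raises, so each segment's CUDA cache is reclaimed even on failure.

def run_one_segment():
    # stand-in for a CUDA-heavy model call
    return torch.zeros(1)

with memory_efficient_inference():
    video = run_one_segment()  # cache is emptied afterwards, even on exceptions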
@dataclass
class AudioSegment:
    """Data class for audio segment information"""

    audio_array: np.ndarray
    start_frame: int
    end_frame: int
    is_last: bool = False
    useful_length: Optional[int] = None
class FramePreprocessor:
    """Handles frame preprocessing including noise and masking"""

    def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 0.1):
        self.noise_mean = noise_mean
        self.noise_std = noise_std
        self.mask_rate = mask_rate

    def add_noise(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
        """Add noise to frames"""
        if self.noise_mean is None or self.noise_std is None:
            return frames
        if rnd_state is None:
            rnd_state = np.random.RandomState()
        shape = frames.shape
        bs = 1 if len(shape) == 4 else shape[0]
        sigma = rnd_state.normal(loc=self.noise_mean, scale=self.noise_std, size=(bs,))
        sigma = np.exp(sigma)
        sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
        noise = rnd_state.randn(*shape) * sigma
        return frames + noise

    def add_mask(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
        """Add mask to frames"""
        if self.mask_rate is None:
            return frames
        if rnd_state is None:
            rnd_state = np.random.RandomState()
        h, w = frames.shape[-2:]
        mask = rnd_state.rand(h, w) > self.mask_rate
        return frames * mask

    def process_prev_frames(self, frames: torch.Tensor) -> torch.Tensor:
        """Process previous frames with noise and masking"""
        frames_np = frames.cpu().detach().numpy()
        frames_np = self.add_noise(frames_np)
        frames_np = self.add_mask(frames_np)
        return torch.from_numpy(frames_np).to(dtype=frames.dtype, device=frames.device)
class AudioProcessor:
    """Handles audio loading and segmentation"""

    def __init__(self, audio_sr: int = 16000, target_fps: int = 16):
        self.audio_sr = audio_sr
        self.target_fps = target_fps

    def load_audio(self, audio_path: str) -> np.ndarray:
        """Load and resample audio"""
        audio_array, ori_sr = ta.load(audio_path)
        audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=self.audio_sr)
        return audio_array.numpy()

    def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
        """Calculate audio range for given frame range"""
        audio_frame_rate = self.audio_sr / self.target_fps
        return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)

    def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
        """Segment audio based on frame requirements"""
        segments = []

        # Calculate intervals
        interval_num = 1
        res_frame_num = 0
        if expected_frames <= max_num_frames:
            interval_num = 1
        else:
            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
            if res_frame_num > 5:
                interval_num += 1

        # Create segments
        for idx in range(interval_num):
            if idx == 0:
                # First segment
                audio_start, audio_end = self.get_audio_range(0, max_num_frames)
                segment_audio = audio_array[audio_start:audio_end]
                useful_length = None
                if expected_frames < max_num_frames:
                    useful_length = segment_audio.shape[0]
                    max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
                    segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
                segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))
            elif res_frame_num > 5 and idx == interval_num - 1:
                # Last segment (might be shorter)
                start_frame = idx * max_num_frames - idx * prev_frame_length
                audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
                segment_audio = audio_array[audio_start:audio_end]
                useful_length = segment_audio.shape[0]
                max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
                segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
                segments.append(AudioSegment(segment_audio, start_frame, expected_frames, True, useful_length))
            else:
                # Middle segments
                start_frame = idx * max_num_frames - idx * prev_frame_length
                end_frame = (idx + 1) * max_num_frames - idx * prev_frame_length
                audio_start, audio_end = self.get_audio_range(start_frame, end_frame)
                segment_audio = audio_array[audio_start:audio_end]
                segments.append(AudioSegment(segment_audio, start_frame, end_frame, False))

        return segments
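A worked example of the interval arithmetic (illustrative values, not from the commit): 10 seconds of 16 kHz audio at 16 fps gives expected_frames = 160 with max_num_frames = 81 per segment, and each later segment re-uses the previous segment's last prev_frame_length = 5 frames.

ap = AudioProcessor(audio_sr=16000, target_fps=16)
audio = np.zeros(10 * 16000)  # 160 frames' worth of samples
segments = ap.segment_audio(audio, expected_frames=160, max_num_frames=81)

# interval_num = int((160 - 81) / 76) + 1 = 2, res_frame_num = 160 - 2 * 76 = 8 > 5,
# so a third (short) last segment is appended, overlapping the previous one by 5 frames.
for s in segments:
    print(s.start_frame, s.end_frame, s.is_last)  # (0, 81), (76, 157), (152, 160)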
class VideoGenerator:
    """Handles video generation for each segment"""

    def __init__(self, model, vae_encoder, vae_decoder, config):
        self.model = model
        self.vae_encoder = vae_encoder
        self.vae_decoder = vae_decoder
        self.config = config
        self.frame_preprocessor = FramePreprocessor()

    def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
        """Prepare previous latents for conditioning"""
        if prev_video is None:
            return None

        device = self.model.device
        dtype = torch.bfloat16
        vae_dtype = torch.float
        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w

        prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)

        # Extract and process last frames
        last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
        last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
        prev_frames[:, :, :prev_frame_length] = last_frames

        prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)

        # Create mask
        prev_token_length = (prev_frame_length - 1) // 4 + 1
        _, nframe, height, width = self.model.scheduler.latents.shape
        frames_n = (nframe - 1) * 4 + 1
        prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)

        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
        prev_mask[:, prev_frame_len:] = 0
        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)

        return {"prev_latents": prev_latents, "prev_mask": prev_mask}

    def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
        """Rearrange mask for WAN model"""
        if mask.ndim == 3:
            mask = mask[None]
        assert mask.ndim == 4
        _, t, h, w = mask.shape
        assert t == ((t - 1) // 4 * 4 + 1)
        mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
        mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
        mask = mask.view(mask.shape[1] // 4, 4, h, w)
        return mask.transpose(0, 1)
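Shape check (illustration only, run inside this module's namespace with dummy arguments): rearranging a pixel-space mask of 81 frames yields a 4-channel mask over 21 latent frames, matching the frames_n = (nframe - 1) * 4 + 1 relation used above.

vg = VideoGenerator(model=None, vae_encoder=None, vae_decoder=None, config=None)
mask = torch.ones((1, 81, 8, 8))
print(vg._wan_mask_rearrange(mask).shape)  # torch.Size([4, 21, 8, 8])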
    @torch.no_grad()
    def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
        """Generate video segment"""
        # Update inputs with audio features
        inputs["audio_encoder_output"] = audio_features

        # Reset scheduler for non-first segments
        if segment_idx > 0:
            self.model.scheduler.reset()

        # Prepare previous latents - ALWAYS needed, even for first segment
        device = self.model.device
        dtype = torch.bfloat16
        vae_dtype = torch.float
        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
        max_num_frames = self.config.target_video_length

        if segment_idx == 0:
            # First segment - create zero frames
            prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
            prev_len = 0
        else:
            # Subsequent segments - use previous video
            previmg_encoder_output = self.prepare_prev_latents(prev_video, prev_frame_length)
            if previmg_encoder_output:
                prev_latents = previmg_encoder_output["prev_latents"]
                prev_len = (prev_frame_length - 1) // 4 + 1
            else:
                # Fallback to zeros if prepare_prev_latents fails
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = 0

        # Create mask for prev_latents
        _, nframe, height, width = self.model.scheduler.latents.shape
        frames_n = (nframe - 1) * 4 + 1
        prev_frame_len = max((prev_len - 1) * 4 + 1, 0)

        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
        prev_mask[:, prev_frame_len:] = 0
        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)

        # Always set previmg_encoder_output
        inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}

        # Run inference loop
        for step_index in range(self.model.scheduler.infer_steps):
            logger.info(f"==> Segment {segment_idx}, Step {step_index} / {self.model.scheduler.infer_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

        # Decode latents
        latents = self.model.scheduler.latents
        generator = self.model.scheduler.generator
        gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)

        return gen_video
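The conditioning-length arithmetic in generate_segment, spelled out (illustration only): 5 previous pixel frames compress to 2 latent tokens, which map back to 5 pixel-mask positions kept at 1.

prev_frame_length = 5
prev_len = (prev_frame_length - 1) // 4 + 1      # 2 latent tokens carried over
prev_frame_len = max((prev_len - 1) * 4 + 1, 0)  # 5 pixel-mask positions kept at 1
print(prev_len, prev_frame_len)                  # 2 5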
def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
    ...
@@ -131,221 +355,226 @@ def adaptive_resize(img):
    return cropped_img, target_h, target_w

Added:

@RUNNER_REGISTER("wan2.1_audio")
class WanAudioRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)
        self._is_initialized = False
        self._audio_adapter_pipe = None
        self._audio_processor = None
        self._video_generator = None
        self._audio_preprocess = None
"""Convert an array to a video directly, gif not supported.
Args:
image_array (np.ndarray): shape should be (f * h * w * 3).
output_path (str): output video file path.
fps (Union[int, float, optional): fps. Defaults to 30.
resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
optional): (height, width) of the output video.
Defaults to None.
disable_log (bool, optional): whether close the ffmepg command info.
Defaults to False.
output_pix_fmt (str): output pix_fmt in ffmpeg command.
Raises:
FileNotFoundError: check output path.
TypeError: check input array.
Returns:
None.
"""
if
not
isinstance
(
image_array
,
np
.
ndarray
):
raise
TypeError
(
"Input should be np.ndarray."
)
assert
image_array
.
ndim
==
4
assert
image_array
.
shape
[
-
1
]
==
3
if
resolution
:
height
,
width
=
resolution
width
+=
width
%
2
height
+=
height
%
2
else
:
image_array
=
pad_for_libx264
(
image_array
)
height
,
width
=
image_array
.
shape
[
1
],
image_array
.
shape
[
2
]
if
lossless
:
command
=
[
"/usr/bin/ffmpeg"
,
"-y"
,
# (optional) overwrite output file if it exists
"-f"
,
"rawvideo"
,
"-s"
,
f
"
{
int
(
width
)
}
x
{
int
(
height
)
}
"
,
# size of one frame
"-pix_fmt"
,
"bgr24"
,
"-r"
,
f
"
{
fps
}
"
,
# frames per second
"-loglevel"
,
"error"
,
"-threads"
,
"4"
,
"-i"
,
"-"
,
# The input comes from a pipe
"-vcodec"
,
"libx264rgb"
,
"-crf"
,
"0"
,
"-an"
,
# Tells FFMPEG not to expect any audio
output_path
,
]
else
:
output_pix_fmt
=
output_pix_fmt
or
"yuv420p"
command
=
[
"/usr/bin/ffmpeg"
,
"-y"
,
# (optional) overwrite output file if it exists
"-f"
,
"rawvideo"
,
"-s"
,
f
"
{
int
(
width
)
}
x
{
int
(
height
)
}
"
,
# size of one frame
"-pix_fmt"
,
"bgr24"
,
"-r"
,
f
"
{
fps
}
"
,
# frames per second
"-loglevel"
,
"error"
,
"-threads"
,
"4"
,
"-i"
,
"-"
,
# The input comes from a pipe
"-vcodec"
,
"libx264"
,
"-pix_fmt"
,
f
"
{
output_pix_fmt
}
"
,
"-an"
,
# Tells FFMPEG not to expect any audio
output_path
,
]
if
output_pix_fmt
is
not
None
:
command
+=
[
"-pix_fmt"
,
output_pix_fmt
]
if
not
disable_log
:
print
(
f
'Running "
{
" "
.
join
(
command
)
}
"'
)
process
=
subprocess
.
Popen
(
command
,
stdin
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
)
if
process
.
stdin
is
None
or
process
.
stderr
is
None
:
raise
BrokenPipeError
(
"No buffer received."
)
index
=
0
while
True
:
if
index
>=
image_array
.
shape
[
0
]:
break
process
.
stdin
.
write
(
image_array
[
index
].
tobytes
())
index
+=
1
process
.
stdin
.
close
()
process
.
stderr
.
close
()
process
.
wait
()
def
pad_for_libx264
(
image_array
):
if
image_array
.
ndim
==
2
or
(
image_array
.
ndim
==
3
and
image_array
.
shape
[
2
]
==
3
):
hei_index
=
0
wid_index
=
1
elif
image_array
.
ndim
==
4
or
(
image_array
.
ndim
==
3
and
image_array
.
shape
[
2
]
!=
3
):
hei_index
=
1
wid_index
=
2
else
:
return
image_array
hei_pad
=
image_array
.
shape
[
hei_index
]
%
2
wid_pad
=
image_array
.
shape
[
wid_index
]
%
2
if
hei_pad
+
wid_pad
>
0
:
pad_width
=
[]
for
dim_index
in
range
(
image_array
.
ndim
):
if
dim_index
==
hei_index
:
pad_width
.
append
((
0
,
hei_pad
))
elif
dim_index
==
wid_index
:
pad_width
.
append
((
0
,
wid_pad
))
else
:
pad_width
.
append
((
0
,
0
))
values
=
0
image_array
=
np
.
pad
(
image_array
,
pad_width
,
mode
=
"constant"
,
constant_values
=
values
)
return
image_array
def
generate_unique_path
(
path
):
if
not
os
.
path
.
exists
(
path
):
return
path
root
,
ext
=
os
.
path
.
splitext
(
path
)
index
=
1
new_path
=
f
"
{
root
}
-
{
index
}{
ext
}
"
while
os
.
path
.
exists
(
new_path
):
index
+=
1
new_path
=
f
"
{
root
}
-
{
index
}{
ext
}
"
return
new_path
def
save_audio
(
audio_array
,
audio_name
:
str
,
video_name
:
str
,
sr
:
int
=
16000
,
output_path
:
Optional
[
str
]
=
None
,
):
logger
.
info
(
f
"Saving audio to
{
audio_name
}
type:
{
type
(
audio_array
)
}
"
)
ta
.
save
(
audio_name
,
torch
.
tensor
(
audio_array
[
None
]),
sample_rate
=
sr
,
)
if
output_path
is
None
:
out_video
=
f
"
{
video_name
[:
-
4
]
}
_with_audio.mp4"
else
:
out_video
=
output_path
parent_dir
=
os
.
path
.
dirname
(
out_video
)
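For reference, the removed pad_for_libx264 helper simply rounds odd spatial sizes up to even ones, which libx264 requires; a quick check with illustrative shapes (run inside this module's namespace):

frames = np.zeros((10, 301, 511, 3), dtype=np.uint8)  # odd height and width
print(pad_for_libx264(frames).shape)                  # (10, 302, 512, 3)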
Added (replacing the old registration of the class at this point in the file, whose __init__ only called super().__init__(config)):

    def initialize_once(self):
        """Initialize all models once for multiple runs"""
        if self._is_initialized:
            return

        logger.info("Initializing models (one-time setup)...")

        # Initialize audio processor
        audio_sr = self.config.get("audio_sr", 16000)
        target_fps = self.config.get("target_fps", 16)
        self._audio_processor = AudioProcessor(audio_sr, target_fps)

        # Load audio feature extractor
        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")

        # Initialize scheduler
        self.init_scheduler()

        self._is_initialized = True
        logger.info("Model initialization complete")

    def init_scheduler(self):
        """Initialize consistency model scheduler"""
        scheduler = ConsistencyModelScheduler(self.config)
        self.model.set_scheduler(scheduler)
Added: load_audio_adapter_lazy(), which replaces the old load_audio_models(). The old method rebuilt the AutoFeatureExtractor and AudioAdapterPipe on every call, carried a typo (`audio_adaper`, now fixed to `audio_adapter`), and used Chinese comments (音频特征提取器 / 音频驱动视频生成 adapter / 音频特征编码器, i.e. audio feature extractor / audio-driven video generation adapter / audio feature encoder), which are replaced by English ones.

    def load_audio_adapter_lazy(self):
        """Lazy load audio adapter when needed"""
        if self._audio_adapter_pipe is not None:
            return self._audio_adapter_pipe

        # Audio adapter
        audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
        audio_adapter = AudioAdapter.from_transformer(
            self.model,
            audio_feature_dim=1024,
            interval=1,
            time_freq_dim=256,
            projection_transformer_layers=4,
        )
        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)

        # Audio encoder
        device = self.model.device
        audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
        return self._audio_adapter_pipe
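The lazy-load pattern above means repeated pipeline runs reuse a single AudioAdapterPipe instead of rebuilding it. A minimal sketch of the same idiom (generic illustration, not the project's API; LazyHolder is hypothetical):

class LazyHolder:
    def __init__(self):
        self._pipe = None

    def get_pipe(self):
        if self._pipe is not None:  # second and later calls are free
            return self._pipe
        self._pipe = object()       # stand-in for the expensive construction
        return self._pipe

holder = LazyHolder()
assert holder.get_pipe() is holder.get_pipe()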
    def prepare_inputs(self):
        """Prepare inputs for the model"""
        image_encoder_output = None
        if os.path.isfile(self.config.image_path):
            with ProfilingContext("Run Img Encoder"):
                vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
            image_encoder_output = {
                "clip_encoder_out": clip_encoder_out,
                "vae_encode_out": vae_encode_out,
            }
        with ProfilingContext("Run Text Encoder"):
            img = Image.open(self.config["image_path"]).convert("RGB")
            text_encoder_output = self.run_text_encoder(self.config["prompt"], img)
        self.set_target_shape()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output, "audio_adapter_pipe": self.load_audio_adapter_lazy()}
    def run_pipeline(self, save_video=True):
        """Optimized pipeline with modular components"""
        # Ensure models are initialized
        self.initialize_once()

        # Initialize video generator if needed
        if self._video_generator is None:
            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config)

        # Prepare inputs
        with memory_efficient_inference():
            if self.config["use_prompt_enhancer"]:
                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
            self.inputs = self.prepare_inputs()
            self.model.scheduler.prepare(self.inputs["image_encoder_output"])

        # Process audio
        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
        video_duration = self.config.get("video_duration", 5)
        target_fps = self.config.get("target_fps", 16)
        max_num_frames = self.config.get("target_video_length", 81)

        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)

        # Segment audio
        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)

        # Generate video segments
        gen_video_list = []
        cut_audio_list = []
        prev_video = None

        for idx, segment in enumerate(audio_segments):
            # Update seed for each segment
            self.config.seed = self.config.seed + idx
            torch.manual_seed(self.config.seed)
            logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")

            # Process audio features
            audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)

            # Generate video segment
            with memory_efficient_inference():
                gen_video = self._video_generator.generate_segment(
                    self.inputs.copy(),  # copy to avoid modifying the original
                    audio_features,
                    prev_video=prev_video,
                    prev_frame_length=5,
                    segment_idx=idx,
                )

            # Extract relevant frames
            start_frame = 0 if idx == 0 else 5
            start_audio_frame = 0 if idx == 0 else int(6 * self._audio_processor.audio_sr / target_fps)

            if segment.is_last and segment.useful_length:
                end_frame = segment.end_frame - segment.start_frame
                gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
                cut_audio_list.append(segment.audio_array[start_audio_frame:segment.useful_length])
            elif segment.useful_length and expected_frames < max_num_frames:
                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                cut_audio_list.append(segment.audio_array[start_audio_frame:segment.useful_length])
            else:
                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                cut_audio_list.append(segment.audio_array[start_audio_frame:])

            # Update prev_video for next iteration
            prev_video = gen_video

            # Clean up GPU memory after each segment
            del gen_video
            torch.cuda.empty_cache()

        # Merge results
        with memory_efficient_inference():
            gen_lvideo = torch.cat(gen_video_list, dim=2).float()
            merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)

            comfyui_images = vae_to_comfyui_image(gen_lvideo)

            # Apply frame interpolation if configured
            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
                interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
                logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
                comfyui_images = self.vfi_model.interpolate_frames(
                    comfyui_images,
                    source_fps=target_fps,
                    target_fps=interpolation_target_fps,
                )
                target_fps = interpolation_target_fps

            # Convert audio to ComfyUI format
            audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
            comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}

            # Save video if requested
            if save_video and self.config.get("save_video_path", None):
                self._save_video_with_audio(comfyui_images, merge_audio, target_fps)

        # Final cleanup
        self.end_run()

        return comfyui_images, comfyui_audio
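A hedged sketch of the configuration keys run_pipeline reads (the key names are taken from the code above; the values and paths are illustrative, and the runner is normally constructed through the project's own entry points rather than from a bare dict):

config = {
    "model_path": "/path/to/wan2.1_audio",  # expects audio_encoder/, image_encoder/, audio_adapter.safetensors inside
    "image_path": "ref.png",
    "prompt": "a person talking",
    "audio_path": "speech.wav",
    "audio_sr": 16000,
    "target_fps": 16,
    "target_video_length": 81,   # frames generated per segment
    "video_duration": 5,         # seconds of video to generate
    "use_prompt_enhancer": False,
    "save_video_path": "out.mp4",
    # optional: "video_frame_interpolation": {"target_fps": 32} to interpolate after generation
}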
    def _save_video_with_audio(self, images, audio_array, fps):
        """Save video with audio"""
        import tempfile

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_tmp:
            video_path = video_tmp.name
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_tmp:
            audio_path = audio_tmp.name

        try:
            # Save video
            save_to_video(images, video_path, fps)

            # Save audio
            ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr)

            # Merge video and audio
            output_path = self.config.get("save_video_path")
            parent_dir = os.path.dirname(output_path)
            if parent_dir and not os.path.exists(parent_dir):
                os.makedirs(parent_dir, exist_ok=True)

            subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_path, "-i", audio_path, output_path])
            logger.info(f"Saved video with audio to: {output_path}")
        finally:
            # Clean up temp files
            if os.path.exists(video_path):
                os.remove(video_path)
            if os.path.exists(audio_path):
                os.remove(audio_path)
    def load_transformer(self):
        """Load transformer with LoRA support"""
        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
        if self.config.get("lora_configs") and self.config.lora_configs:
            ...
@@ -361,19 +590,21 @@ class WanAudioRunner(WanRunner):
        return base_model

Docstrings are the only additions in the two methods below:

    def load_image_encoder(self):
        """Load image encoder"""
        clip_model_dir = self.config["model_path"] + "/image_encoder"
        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
        return image_encoder

    def run_image_encoder(self, config, vae_model):
        """Run image encoder"""
        ref_img = Image.open(config.image_path)
        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
        ref_img = torch.from_numpy(ref_img).to(vae_model.device)
        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
        ref_img = ref_img[:, :3]
        # Resize and crop image
        cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
        config.tgt_h = tgt_h
        config.tgt_w = tgt_w
@@ -384,36 +615,13 @@ class WanAudioRunner(WanRunner):
        config.lat_h = lat_h
        config.lat_w = lat_w
        vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
        if isinstance(vae_encode_out, list):  # convert the list to a tensor
            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
        return vae_encode_out, clip_encoder_out

Removed: the old run_input_encoder_internal() method, which duplicated what prepare_inputs() now does (Chinese comment translated):

    def run_input_encoder_internal(self):
        image_encoder_output = None
        if os.path.isfile(self.config.image_path):
            with ProfilingContext("Run Img Encoder"):
                vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
            image_encoder_output = {
                "clip_encoder_out": clip_encoder_out,
                "vae_encode_out": vae_encode_out,
            }
            logger.info(f"clip_encoder_out: {clip_encoder_out.shape} vae_encode_out: {vae_encode_out.shape}")
        with ProfilingContext("Run Text Encoder"):
            logger.info(f"Prompt: {self.config['prompt']}")
            img = Image.open(self.config["image_path"]).convert("RGB")
            text_encoder_output = self.run_text_encoder(self.config["prompt"], img)
        self.set_target_shape()
        self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
        # del self.image_encoder  # the reference CLIP model is only used once
        gc.collect()
        torch.cuda.empty_cache()

Unchanged context (docstring added):

    def set_target_shape(self):
        """Set target shape for generation"""
        ret = {}
        num_channels_latents = 16
        if self.config.task == "i2v":
@@ -427,216 +635,7 @@ class WanAudioRunner(WanRunner):
            ret["lat_w"] = self.config.lat_w
        else:
            error_msg = "t2v task is not supported in WanAudioRunner"
-           assert 1 == 0, error_msg
+           assert False, error_msg
        ret["target_shape"] = self.config.target_shape
        return ret
Removed: the old monolithic run() method (with its nested helpers) and the old run_pipeline(); their logic is now split across AudioProcessor, FramePreprocessor, VideoGenerator, and the new run_pipeline() above. Chinese comments are translated in place.

    def run(self, save_video=True):
        def load_audio(in_path: str, sr: float = 16000):
            audio_array, ori_sr = ta.load(in_path)
            audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
            return audio_array.numpy()

        def get_audio_range(start_frame: int, end_frame: int, fps: float, audio_sr: float = 16000):
            audio_frame_rate = audio_sr / fps
            return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)

        def wan_mask_rearrange(mask: torch.Tensor):
            # mask: 1, T, H, W, where 1 means the input mask is one-channel
            if mask.ndim == 3:
                mask = mask[None]
            assert mask.ndim == 4
            _, t, h, w = mask.shape
            assert t == ((t - 1) // 4 * 4 + 1)
            mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
            mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
            mask = mask.view(mask.shape[1] // 4, 4, h, w)
            return mask.transpose(0, 1)  # 4, T // 4, H, W

        self.inputs["audio_adapter_pipe"] = self.load_audio_models()

        # process audio
        audio_sr = self.config.get("audio_sr", 16000)
        max_num_frames = self.config.get("target_video_length", 81)  # wan2.1 generates at most 81 frames (5 s at 16 fps) per segment
        target_fps = self.config.get("target_fps", 16)  # frame rate used for audio/video synchronization
        video_duration = self.config.get("video_duration", 5)  # desired output video duration
        audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
        audio_len = int(audio_array.shape[0] / audio_sr * target_fps)

        prev_frame_length = 5
        prev_token_length = (prev_frame_length - 1) // 4 + 1
        max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
        interval_num = 1
        # expected_frames
        expected_frames = min(max(1, int(float(video_duration) * target_fps)), audio_len)
        res_frame_num = 0
        if expected_frames <= max_num_frames:
            interval_num = 1
        else:
            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
            if res_frame_num > 5:
                interval_num += 1

        audio_start, audio_end = get_audio_range(0, expected_frames, fps=target_fps, audio_sr=audio_sr)
        audio_array_ori = audio_array[audio_start:audio_end]

        gen_video_list = []
        cut_audio_list = []

        # reference latents
        tgt_h = self.config.tgt_h
        tgt_w = self.config.tgt_w
        device = self.model.scheduler.latents.device
        dtype = torch.bfloat16
        vae_dtype = torch.float

        for idx in range(interval_num):
            self.config.seed = self.config.seed + idx
            torch.manual_seed(self.config.seed)
            logger.info(f"### manual_seed: {self.config.seed} ####")
            useful_length = -1
            if idx == 0:
                # first segment: condition is zero padding
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = 0
                audio_start, audio_end = get_audio_range(0, max_num_frames, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                if expected_frames < max_num_frames:
                    useful_length = audio_array.shape[0]
                    audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            elif res_frame_num > 5 and idx == interval_num - 1:
                # last segment: may be shorter than 81 frames
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
                last_frames = last_frames.cpu().detach().numpy()
                last_frames = add_noise_to_frames(last_frames)
                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
                prev_frames[:, :, :prev_frame_length] = last_frames
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = prev_token_length
                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                useful_length = audio_array.shape[0]
                audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            else:
                # middle segments: full 81 frames with prev_latents
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
                last_frames = last_frames.cpu().detach().numpy()
                last_frames = add_noise_to_frames(last_frames)  # mean: -3.0, std: 0.5
                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
                prev_frames[:, :, :prev_frame_length] = last_frames
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = prev_token_length
                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)

            self.inputs["audio_encoder_output"] = audio_input_feat.to(device)

            if idx != 0:
                self.model.scheduler.reset()

            if prev_latents is not None:
                _, nframe, height, width = self.model.scheduler.latents.shape  # bs = 1
                frames_n = (nframe - 1) * 4 + 1
                prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
                prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
                prev_mask[:, prev_frame_len:] = 0
                prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
                previmg_encoder_output = {
                    "prev_latents": prev_latents,
                    "prev_mask": prev_mask,
                }
                self.inputs["previmg_encoder_output"] = previmg_encoder_output

            for step_index in range(self.model.scheduler.infer_steps):
                logger.info(f"==> step_index: {step_index} / {self.model.scheduler.infer_steps}")
                with ProfilingContext4Debug("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)
                with ProfilingContext4Debug("infer"):
                    self.model.infer(self.inputs)
                with ProfilingContext4Debug("step_post"):
                    self.model.scheduler.step_post()

            latents = self.model.scheduler.latents
            generator = self.model.scheduler.generator
            gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
            gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)

            start_frame = 0 if idx == 0 else prev_frame_length
            start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
            if res_frame_num > 5 and idx == interval_num - 1:
                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
            elif expected_frames < max_num_frames and useful_length != -1:
                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
            else:
                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                cut_audio_list.append(audio_array[start_audio_frame:])

        gen_lvideo = torch.cat(gen_video_list, dim=2).float()
        merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
        comfyui_images = vae_to_comfyui_image(gen_lvideo)

        # Apply frame interpolation if configured
        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
            comfyui_images = self.vfi_model.interpolate_frames(
                comfyui_images,
                source_fps=target_fps,
                target_fps=interpolation_target_fps,
            )
            # Update target_fps for saving
            target_fps = interpolation_target_fps

        # Convert audio to ComfyUI format
        # Convert numpy array to torch tensor and add batch dimension
        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)  # [batch, channels, samples]
        comfyui_audio = {"waveform": audio_waveform, "sample_rate": audio_sr}

        # Save video if requested
        if save_video and self.config.get("save_video_path", None):
            out_path = os.path.join("./", "video_merge.mp4")
            audio_file = os.path.join("./", "audio_merge.wav")
            # Use the updated target_fps (after interpolation if applied)
            save_to_video(comfyui_images, out_path, target_fps)
            save_audio(merge_audio, audio_file, out_path, output_path=self.config.get("save_video_path", None))
            os.remove(out_path)
            os.remove(audio_file)

        return comfyui_images, comfyui_audio

    def run_pipeline(self, save_video=True):
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
        self.run_input_encoder_internal()
        self.set_target_shape()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        images, audio = self.run(save_video)  # run() returns both images and audio
        self.end_run()
        gc.collect()
        torch.cuda.empty_cache()
        return images, audio