xuwx1 / LightX2V · Commits

Commit d048b178, authored Jul 23, 2025 by gaclove
Parent: 5bd9bdbd

refactor: enhance WanAudioRunner to improve audio handling and frame interpolation

Showing 1 changed file with 458 additions and 459 deletions

lightx2v/models/runners/wan/wan_audio_runner.py (+458 / -459, view file @ d048b178)
@@ -4,6 +4,10 @@ import numpy as np
 import torch
 import torchvision.transforms.functional as TF
 from PIL import Image
+from contextlib import contextmanager
+from typing import Optional, Tuple, Union, List, Dict, Any
+from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.default_runner import DefaultRunner
@@ -34,33 +38,41 @@ from torchvision.transforms.functional import resize
 import subprocess
 import warnings
-from typing import Optional, Tuple, Union
-
-
-def add_mask_to_frames(
-    frames: np.ndarray,
-    mask_rate: float = 0.1,
-    rnd_state: np.random.RandomState = None,
-) -> np.ndarray:
-    if mask_rate is None:
-        return frames
-    if rnd_state is None:
-        rnd_state = np.random.RandomState()
-    h, w = frames.shape[-2:]
-    mask = rnd_state.rand(h, w) > mask_rate
-    frames = frames * mask
-    return frames
-
-
-def add_noise_to_frames(
-    frames: np.ndarray,
-    noise_mean: float = -3.0,
-    noise_std: float = 0.5,
-    rnd_state: np.random.RandomState = None,
-) -> np.ndarray:
-    if noise_mean is None or noise_std is None:
-        return frames
-    if rnd_state is None:
+
+
+@contextmanager
+def memory_efficient_inference():
+    """Context manager for memory-efficient inference"""
+    try:
+        yield
+    finally:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+
+@dataclass
+class AudioSegment:
+    """Data class for audio segment information"""
+
+    audio_array: np.ndarray
+    start_frame: int
+    end_frame: int
+    is_last: bool = False
+    useful_length: Optional[int] = None
+
+
+class FramePreprocessor:
+    """Handles frame preprocessing including noise and masking"""
+
+    def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 0.1):
+        self.noise_mean = noise_mean
+        self.noise_std = noise_std
+        self.mask_rate = mask_rate
+
+    def add_noise(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+        """Add noise to frames"""
+        if self.noise_mean is None or self.noise_std is None:
+            return frames
+        if rnd_state is None:
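The defaults above (noise_mean=-3.0, noise_std=0.5) feed the log-normal noise scale drawn at the top of the next hunk, sigma = exp(N(noise_mean, noise_std)). A small illustrative check, not part of the commit, of how strong that perturbation typically is:

import numpy as np

# Per-batch noise scale used by FramePreprocessor.add_noise: sigma = exp(N(-3.0, 0.5)).
rnd = np.random.RandomState(0)
sigma = np.exp(rnd.normal(loc=-3.0, scale=0.5, size=(4,)))
print(sigma)         # values clustered around 0.05
print(np.exp(-3.0))  # 0.0498..., the median scale, so conditioning frames are only lightly noised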
@@ -68,13 +80,225 @@ def add_noise_to_frames(
-    shape = frames.shape
-    bs = 1 if len(shape) == 4 else shape[0]
-    sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
-    sigma = np.exp(sigma)
-    sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
-    noise = rnd_state.randn(*shape) * sigma
-    frames = frames + noise
-    return frames
+            rnd_state = np.random.RandomState()
+        shape = frames.shape
+        bs = 1 if len(shape) == 4 else shape[0]
+        sigma = rnd_state.normal(loc=self.noise_mean, scale=self.noise_std, size=(bs,))
+        sigma = np.exp(sigma)
+        sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
+        noise = rnd_state.randn(*shape) * sigma
+        return frames + noise
+
+    def add_mask(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+        """Add mask to frames"""
+        if self.mask_rate is None:
+            return frames
+        if rnd_state is None:
+            rnd_state = np.random.RandomState()
+        h, w = frames.shape[-2:]
+        mask = rnd_state.rand(h, w) > self.mask_rate
+        return frames * mask
+
+    def process_prev_frames(self, frames: torch.Tensor) -> torch.Tensor:
+        """Process previous frames with noise and masking"""
+        frames_np = frames.cpu().detach().numpy()
+        frames_np = self.add_noise(frames_np)
+        frames_np = self.add_mask(frames_np)
+        return torch.from_numpy(frames_np).to(dtype=frames.dtype, device=frames.device)
+
+
+class AudioProcessor:
+    """Handles audio loading and segmentation"""
+
+    def __init__(self, audio_sr: int = 16000, target_fps: int = 16):
+        self.audio_sr = audio_sr
+        self.target_fps = target_fps
+
+    def load_audio(self, audio_path: str) -> np.ndarray:
+        """Load and resample audio"""
+        audio_array, ori_sr = ta.load(audio_path)
+        audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=self.audio_sr)
+        return audio_array.numpy()
+
+    def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
+        """Calculate audio range for given frame range"""
+        audio_frame_rate = self.audio_sr / self.target_fps
+        return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
+
+    def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
+        """Segment audio based on frame requirements"""
+        segments = []
+
+        # Calculate intervals
+        interval_num = 1
+        res_frame_num = 0
+        if expected_frames <= max_num_frames:
+            interval_num = 1
+        else:
+            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
+            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
+            if res_frame_num > 5:
+                interval_num += 1
+
+        # Create segments
+        for idx in range(interval_num):
+            if idx == 0:
+                # First segment
+                audio_start, audio_end = self.get_audio_range(0, max_num_frames)
+                segment_audio = audio_array[audio_start:audio_end]
+                useful_length = None
+                if expected_frames < max_num_frames:
+                    useful_length = segment_audio.shape[0]
+                    max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
+                    segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
+                segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))
+            elif res_frame_num > 5 and idx == interval_num - 1:
+                # Last segment (might be shorter)
+                start_frame = idx * max_num_frames - idx * prev_frame_length
+                audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
+                segment_audio = audio_array[audio_start:audio_end]
+                useful_length = segment_audio.shape[0]
+                max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
+                segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
+                segments.append(AudioSegment(segment_audio, start_frame, expected_frames, True, useful_length))
+            else:
+                # Middle segments
+                start_frame = idx * max_num_frames - idx * prev_frame_length
+                end_frame = (idx + 1) * max_num_frames - idx * prev_frame_length
+                audio_start, audio_end = self.get_audio_range(start_frame, end_frame)
+                segment_audio = audio_array[audio_start:audio_end]
+                segments.append(AudioSegment(segment_audio, start_frame, end_frame, False))
+
+        return segments
+
+
+class VideoGenerator:
+    """Handles video generation for each segment"""
+
+    def __init__(self, model, vae_encoder, vae_decoder, config):
+        self.model = model
+        self.vae_encoder = vae_encoder
+        self.vae_decoder = vae_decoder
+        self.config = config
+        self.frame_preprocessor = FramePreprocessor()
+
+    def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
+        """Prepare previous latents for conditioning"""
+        if prev_video is None:
+            return None
+
+        device = self.model.device
+        dtype = torch.bfloat16
+        vae_dtype = torch.float
+        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
+
+        prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)
+
+        # Extract and process last frames
+        last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
+        last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
+        prev_frames[:, :, :prev_frame_length] = last_frames
+
+        prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+
+        # Create mask
+        prev_token_length = (prev_frame_length - 1) // 4 + 1
+        _, nframe, height, width = self.model.scheduler.latents.shape
+        frames_n = (nframe - 1) * 4 + 1
+        prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
+
+        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        prev_mask[:, prev_frame_len:] = 0
+        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+
+        return {"prev_latents": prev_latents, "prev_mask": prev_mask}
+
+    def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
+        """Rearrange mask for WAN model"""
+        if mask.ndim == 3:
+            mask = mask[None]
+        assert mask.ndim == 4
+        _, t, h, w = mask.shape
+        assert t == ((t - 1) // 4 * 4 + 1)
+        mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
+        mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
+        mask = mask.view(mask.shape[1] // 4, 4, h, w)
+        return mask.transpose(0, 1)
+
+    @torch.no_grad()
+    def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
+        """Generate video segment"""
+        # Update inputs with audio features
+        inputs["audio_encoder_output"] = audio_features
+
+        # Reset scheduler for non-first segments
+        if segment_idx > 0:
+            self.model.scheduler.reset()
+
+        # Prepare previous latents - ALWAYS needed, even for first segment
+        device = self.model.device
+        dtype = torch.bfloat16
+        vae_dtype = torch.float
+        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
+        max_num_frames = self.config.target_video_length
+
+        if segment_idx == 0:
+            # First segment - create zero frames
+            prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+            prev_len = 0
+        else:
+            # Subsequent segments - use previous video
+            previmg_encoder_output = self.prepare_prev_latents(prev_video, prev_frame_length)
+            if previmg_encoder_output:
+                prev_latents = previmg_encoder_output["prev_latents"]
+                prev_len = (prev_frame_length - 1) // 4 + 1
+            else:
+                # Fallback to zeros if prepare_prev_latents fails
+                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
+                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+                prev_len = 0
+
+        # Create mask for prev_latents
+        _, nframe, height, width = self.model.scheduler.latents.shape
+        frames_n = (nframe - 1) * 4 + 1
+        prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
+
+        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        prev_mask[:, prev_frame_len:] = 0
+        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+
+        # Always set previmg_encoder_output
+        inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
+
+        # Run inference loop
+        for step_index in range(self.model.scheduler.infer_steps):
+            logger.info(f"==> Segment {segment_idx}, Step {step_index} / {self.model.scheduler.infer_steps}")
+
+            with ProfilingContext4Debug("step_pre"):
+                self.model.scheduler.step_pre(step_index=step_index)
+
+            with ProfilingContext4Debug("infer"):
+                self.model.infer(inputs)
+
+            with ProfilingContext4Debug("step_post"):
+                self.model.scheduler.step_post()
+
+        # Decode latents
+        latents = self.model.scheduler.latents
+        generator = self.model.scheduler.generator
+        gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
+        gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
+
+        return gen_video
+
+
 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
     tgt_ar = tgt_h / tgt_w
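To make the windowing in AudioProcessor.segment_audio concrete, here is a small worked example, not part of the commit, that redoes the same arithmetic with the defaults used elsewhere in this file (81-frame windows, 5-frame overlap, 16 fps video, 16 kHz audio); the end of the last window is simplified to min(..., expected_frames):

# Segmentation arithmetic for a 200-frame request (12.5 s at 16 fps).
expected_frames, max_num_frames, prev_frame_length = 200, 81, 5
audio_sr, target_fps = 16000, 16

interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
if res_frame_num > 5:
    interval_num += 1
print(interval_num, res_frame_num)  # 3 segments, 48 frames left for the last one

samples_per_frame = audio_sr / target_fps  # 1000 audio samples per video frame
for idx in range(interval_num):
    start = 0 if idx == 0 else idx * max_num_frames - idx * prev_frame_length
    end = max_num_frames if idx == 0 else min((idx + 1) * max_num_frames - idx * prev_frame_length, expected_frames)
    print(idx, (start, end), (round(start * samples_per_frame), round((end + 1) * samples_per_frame)))
# frame windows (0, 81), (76, 157), (152, 200): each segment overlaps the previous one by 5 frames,
# which is exactly the prev_frame_length tail that conditions the next segment.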
@@ -131,267 +355,69 @@ def adaptive_resize(img):
     return cropped_img, target_h, target_w


-def array_to_video(
-    image_array: np.ndarray,
-    output_path: str,
-    fps: int | float = 30,
-    resolution: tuple[int, int] | tuple[float, float] | None = None,
-    disable_log: bool = False,
-    lossless: bool = True,
-    output_pix_fmt: str = "yuv420p",
-) -> None:
-    """Convert an array to a video directly, gif not supported.
-
-    Args:
-        image_array (np.ndarray): shape should be (f * h * w * 3).
-        output_path (str): output video file path.
-        fps (Union[int, float], optional): fps. Defaults to 30.
-        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
-            optional): (height, width) of the output video.
-            Defaults to None.
-        disable_log (bool, optional): whether to suppress the ffmpeg command info.
-            Defaults to False.
-        output_pix_fmt (str): output pix_fmt in ffmpeg command.
-
-    Raises:
-        FileNotFoundError: check output path.
-        TypeError: check input array.
-
-    Returns:
-        None.
-    """
-    if not isinstance(image_array, np.ndarray):
-        raise TypeError("Input should be np.ndarray.")
-    assert image_array.ndim == 4
-    assert image_array.shape[-1] == 3
-    if resolution:
-        height, width = resolution
-        width += width % 2
-        height += height % 2
-    else:
-        image_array = pad_for_libx264(image_array)
-        height, width = image_array.shape[1], image_array.shape[2]
-    if lossless:
-        command = [
-            "/usr/bin/ffmpeg",
-            "-y",  # (optional) overwrite output file if it exists
-            "-f",
-            "rawvideo",
-            "-s",
-            f"{int(width)}x{int(height)}",  # size of one frame
-            "-pix_fmt",
-            "bgr24",
-            "-r",
-            f"{fps}",  # frames per second
-            "-loglevel",
-            "error",
-            "-threads",
-            "4",
-            "-i",
-            "-",  # The input comes from a pipe
-            "-vcodec",
-            "libx264rgb",
-            "-crf",
-            "0",
-            "-an",  # Tells FFMPEG not to expect any audio
-            output_path,
-        ]
-    else:
-        output_pix_fmt = output_pix_fmt or "yuv420p"
-        command = [
-            "/usr/bin/ffmpeg",
-            "-y",  # (optional) overwrite output file if it exists
-            "-f",
-            "rawvideo",
-            "-s",
-            f"{int(width)}x{int(height)}",  # size of one frame
-            "-pix_fmt",
-            "bgr24",
-            "-r",
-            f"{fps}",  # frames per second
-            "-loglevel",
-            "error",
-            "-threads",
-            "4",
-            "-i",
-            "-",  # The input comes from a pipe
-            "-vcodec",
-            "libx264",
-            "-pix_fmt",
-            f"{output_pix_fmt}",
-            "-an",  # Tells FFMPEG not to expect any audio
-            output_path,
-        ]
-    if output_pix_fmt is not None:
-        command += ["-pix_fmt", output_pix_fmt]
-    if not disable_log:
-        print(f'Running "{" ".join(command)}"')
-    process = subprocess.Popen(
-        command,
-        stdin=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-    )
-    if process.stdin is None or process.stderr is None:
-        raise BrokenPipeError("No buffer received.")
-    index = 0
-    while True:
-        if index >= image_array.shape[0]:
-            break
-        process.stdin.write(image_array[index].tobytes())
-        index += 1
-    process.stdin.close()
-    process.stderr.close()
-    process.wait()
-
-
-def pad_for_libx264(image_array):
-    if image_array.ndim == 2 or (image_array.ndim == 3 and image_array.shape[2] == 3):
-        hei_index = 0
-        wid_index = 1
-    elif image_array.ndim == 4 or (image_array.ndim == 3 and image_array.shape[2] != 3):
-        hei_index = 1
-        wid_index = 2
-    else:
-        return image_array
-    hei_pad = image_array.shape[hei_index] % 2
-    wid_pad = image_array.shape[wid_index] % 2
-    if hei_pad + wid_pad > 0:
-        pad_width = []
-        for dim_index in range(image_array.ndim):
-            if dim_index == hei_index:
-                pad_width.append((0, hei_pad))
-            elif dim_index == wid_index:
-                pad_width.append((0, wid_pad))
-            else:
-                pad_width.append((0, 0))
-        values = 0
-        image_array = np.pad(image_array, pad_width, mode="constant", constant_values=values)
-    return image_array
-
-
-def generate_unique_path(path):
-    if not os.path.exists(path):
-        return path
-    root, ext = os.path.splitext(path)
-    index = 1
-    new_path = f"{root}-{index}{ext}"
-    while os.path.exists(new_path):
-        index += 1
-        new_path = f"{root}-{index}{ext}"
-    return new_path
-
-
-def save_audio(
-    audio_array,
-    audio_name: str,
-    video_name: str,
-    sr: int = 16000,
-    output_path: Optional[str] = None,
-):
-    logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
-    ta.save(
-        audio_name,
-        torch.tensor(audio_array[None]),
-        sample_rate=sr,
-    )
-    if output_path is None:
-        out_video = f"{video_name[:-4]}_with_audio.mp4"
-    else:
-        out_video = output_path
-    parent_dir = os.path.dirname(out_video)
-    if parent_dir and not os.path.exists(parent_dir):
-        os.makedirs(parent_dir, exist_ok=True)
-    if os.path.exists(out_video):
-        os.remove(out_video)
-    subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_name, "-i", audio_name, out_video])
-    return out_video
-
-
-@RUNNER_REGISTER("wan2.1_audio")
-class WanAudioRunner(WanRunner):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def init_scheduler(self):
-        scheduler = ConsistencyModelScheduler(self.config)
-        self.model.set_scheduler(scheduler)
-
-    def load_audio_models(self):
-        # Audio feature extractor
-        self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
-        # Audio-driven video generation adapter
-        audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
-        audio_adaper = AudioAdapter.from_transformer(
-            self.model,
-            audio_feature_dim=1024,
-            interval=1,
-            time_freq_dim=256,
-            projection_transformer_layers=4,
-        )
-        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, audio_adapter_path, strict=False)
-        # Audio feature encoder
-        device = self.model.device
-        audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-        audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
-        return audio_adapter_pipe
-
-    def load_transformer(self):
-        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
-            lora_wrapper = WanLoraWrapper(base_model)
-            for lora_config in self.config.lora_configs:
-                lora_path = lora_config["path"]
-                strength = lora_config.get("strength", 1.0)
-                lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, strength)
-                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
-        return base_model
-
-    def load_image_encoder(self):
-        clip_model_dir = self.config["model_path"] + "/image_encoder"
-        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
-        return image_encoder
-
-    def run_image_encoder(self, config, vae_model):
-        ref_img = Image.open(config.image_path)
-        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
-        ref_img = torch.from_numpy(ref_img).to(vae_model.device)
-        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
-        ref_img = ref_img[:, :3]
-        # resize and crop image
-        cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
-        config.tgt_h = tgt_h
-        config.tgt_w = tgt_w
-        clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
-        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
-        lat_h, lat_w = tgt_h // 8, tgt_w // 8
-        config.lat_h = lat_h
-        config.lat_w = lat_w
-        vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
-        if isinstance(vae_encode_out, list):
-            # convert list to tensor
-            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
-        return vae_encode_out, clip_encoder_out
-
-    def run_input_encoder_internal(self):
+@RUNNER_REGISTER("wan2.1_audio")
+class WanAudioRunner(WanRunner):
+    def __init__(self, config):
+        super().__init__(config)
+        self._is_initialized = False
+        self._audio_adapter_pipe = None
+        self._audio_processor = None
+        self._video_generator = None
+        self._audio_preprocess = None
+
+    def initialize_once(self):
+        """Initialize all models once for multiple runs"""
+        if self._is_initialized:
+            return
+
+        logger.info("Initializing models (one-time setup)...")
+
+        # Initialize audio processor
+        audio_sr = self.config.get("audio_sr", 16000)
+        target_fps = self.config.get("target_fps", 16)
+        self._audio_processor = AudioProcessor(audio_sr, target_fps)
+
+        # Load audio feature extractor
+        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+
+        # Initialize scheduler
+        self.init_scheduler()
+
+        self._is_initialized = True
+        logger.info("Model initialization complete")
+
+    def init_scheduler(self):
+        """Initialize consistency model scheduler"""
+        scheduler = ConsistencyModelScheduler(self.config)
+        self.model.set_scheduler(scheduler)
+
+    def load_audio_adapter_lazy(self):
+        """Lazy load audio adapter when needed"""
+        if self._audio_adapter_pipe is not None:
+            return self._audio_adapter_pipe
+
+        # Audio adapter
+        audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
+        audio_adapter = AudioAdapter.from_transformer(
+            self.model,
+            audio_feature_dim=1024,
+            interval=1,
+            time_freq_dim=256,
+            projection_transformer_layers=4,
+        )
+        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
+
+        # Audio encoder
+        device = self.model.device
+        audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
+        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
+
+        return self._audio_adapter_pipe
+
+    def prepare_inputs(self):
+        """Prepare inputs for the model"""
         image_encoder_output = None
         if os.path.isfile(self.config.image_path):
             with ProfilingContext("Run Img Encoder"):
                 vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
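Both the removed and the re-added run_image_encoder normalize the reference image to [-1, 1] and derive the latent resolution by dividing the target height and width by 8. A quick standalone check of that arithmetic, not part of the commit (the 480x832 resolution is only an illustrative value, not taken from the diff):

import numpy as np

# Pixel normalization used on the reference image: [0, 255] -> [-1, 1].
pixels = np.array([0, 127.5, 255], dtype=np.float32)
print((pixels - 127.5) / 127.5)  # [-1.  0.  1.]

# Latent resolution: target H/W divided by the spatial VAE factor of 8.
tgt_h, tgt_w = 480, 832          # illustrative target resolution
print(tgt_h // 8, tgt_w // 8)    # 60 104  -> config.lat_h, config.lat_w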
@@ -399,204 +425,98 @@ class WanAudioRunner(WanRunner):
                     "clip_encoder_out": clip_encoder_out,
                     "vae_encode_out": vae_encode_out,
                 }
+        logger.info(f"clip_encoder_out: {clip_encoder_out.shape} vae_encode_out: {vae_encode_out.shape}")

         with ProfilingContext("Run Text Encoder"):
+            logger.info(f"Prompt: {self.config['prompt']}")
             img = Image.open(self.config["image_path"]).convert("RGB")
             text_encoder_output = self.run_text_encoder(self.config["prompt"], img)

         self.set_target_shape()
-        self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
-        # del self.image_encoder  # drop the reference CLIP model, it is only used once
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def set_target_shape(self):
-        ret = {}
-        num_channels_latents = 16
-        if self.config.task == "i2v":
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                self.config.lat_h,
-                self.config.lat_w,
-            )
-            ret["lat_h"] = self.config.lat_h
-            ret["lat_w"] = self.config.lat_w
-        else:
-            error_msg = "t2v task is not supported in WanAudioRunner"
-            assert 1 == 0, error_msg
-        ret["target_shape"] = self.config.target_shape
-        return ret
-
-    def run(self, save_video=True):
-        def load_audio(in_path: str, sr: float = 16000):
-            audio_array, ori_sr = ta.load(in_path)
-            audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
-            return audio_array.numpy()
-
-        def get_audio_range(start_frame: int, end_frame: int, fps: float, audio_sr: float = 16000):
-            audio_frame_rate = audio_sr / fps
-            return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
-
-        def wan_mask_rearrange(mask: torch.Tensor):
-            # mask: 1, T, H, W, where 1 means the input mask is one-channel
-            if mask.ndim == 3:
-                mask = mask[None]
-            assert mask.ndim == 4
-            _, t, h, w = mask.shape
-            assert t == ((t - 1) // 4 * 4 + 1)
-            mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
-            mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
-            mask = mask.view(mask.shape[1] // 4, 4, h, w)
-            return mask.transpose(0, 1)  # 4, T // 4, H, W
-
-        self.inputs["audio_adapter_pipe"] = self.load_audio_models()
-        # process audio
-        audio_sr = self.config.get("audio_sr", 16000)
-        max_num_frames = self.config.get("target_video_length", 81)  # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
-        target_fps = self.config.get("target_fps", 16)  # frame rate used to keep audio and video in sync
-        video_duration = self.config.get("video_duration", 5)  # desired output video duration
-        audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
-        audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
-        prev_frame_length = 5
-        prev_token_length = (prev_frame_length - 1) // 4 + 1
-        max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
-        interval_num = 1
-        # expected_frames
-        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
-        res_frame_num = 0
-        if expected_frames <= max_num_frames:
-            interval_num = 1
-        else:
-            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
-            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
-            if res_frame_num > 5:
-                interval_num += 1
-        audio_start, audio_end = get_audio_range(0, expected_frames, fps=target_fps, audio_sr=audio_sr)
-        audio_array_ori = audio_array[audio_start:audio_end]
-        gen_video_list = []
-        cut_audio_list = []
-        # reference latents
-        tgt_h = self.config.tgt_h
-        tgt_w = self.config.tgt_w
-        device = self.model.scheduler.latents.device
-        dtype = torch.bfloat16
-        vae_dtype = torch.float
-        for idx in range(interval_num):
-            self.config.seed = self.config.seed + idx
-            torch.manual_seed(self.config.seed)
-            logger.info(f"### manual_seed: {self.config.seed} ####")
-            useful_length = -1
-            if idx == 0:
-                # first segment: condition is zero padding
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = 0
-                audio_start, audio_end = get_audio_range(0, max_num_frames, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                if expected_frames < max_num_frames:
-                    useful_length = audio_array.shape[0]
-                    audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-            elif res_frame_num > 5 and idx == interval_num - 1:
-                # the last segment may have fewer than 81 frames
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
-                last_frames = last_frames.cpu().detach().numpy()
-                last_frames = add_noise_to_frames(last_frames)
-                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
-                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
-                prev_frames[:, :, :prev_frame_length] = last_frames
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = prev_token_length
-                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                useful_length = audio_array.shape[0]
-                audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-            else:
-                # middle segments: full 81 frames with prev latents
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
-                last_frames = last_frames.cpu().detach().numpy()
-                last_frames = add_noise_to_frames(last_frames)  # mean:-3.0 std:0.5
-                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
-                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
-                prev_frames[:, :, :prev_frame_length] = last_frames
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = prev_token_length
-                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-            self.inputs["audio_encoder_output"] = audio_input_feat.to(device)
-            if idx != 0:
-                self.model.scheduler.reset()
-            if prev_latents is not None:
-                _, nframe, height, width = self.model.scheduler.latents.shape  # bs = 1
-                frames_n = (nframe - 1) * 4 + 1
-                prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
-                prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-                prev_mask[:, prev_frame_len:] = 0
-                prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
-                previmg_encoder_output = {
-                    "prev_latents": prev_latents,
-                    "prev_mask": prev_mask,
-                }
-                self.inputs["previmg_encoder_output"] = previmg_encoder_output
-            for step_index in range(self.model.scheduler.infer_steps):
-                logger.info(f"==> step_index: {step_index} / {self.model.scheduler.infer_steps}")
-                with ProfilingContext4Debug("step_pre"):
-                    self.model.scheduler.step_pre(step_index=step_index)
-                with ProfilingContext4Debug("infer"):
-                    self.model.infer(self.inputs)
-                with ProfilingContext4Debug("step_post"):
-                    self.model.scheduler.step_post()
-            latents = self.model.scheduler.latents
-            generator = self.model.scheduler.generator
-            gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-            gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
-            start_frame = 0 if idx == 0 else prev_frame_length
-            start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
-            if res_frame_num > 5 and idx == interval_num - 1:
-                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
-            elif expected_frames < max_num_frames and useful_length != -1:
-                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
-            else:
-                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:])
-        gen_lvideo = torch.cat(gen_video_list, dim=2).float()
-        merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
-        comfyui_images = vae_to_comfyui_image(gen_lvideo)
-        # Apply frame interpolation if configured
-        if "video_frame_interpolation" in self.config:
-            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
-            interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
-            logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
-            comfyui_images = self.vfi_model.interpolate_frames(
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output, "audio_adapter_pipe": self.load_audio_adapter_lazy()}
+
+    def run_pipeline(self, save_video=True):
+        """Optimized pipeline with modular components"""
+        # Ensure models are initialized
+        self.initialize_once()
+
+        # Initialize video generator if needed
+        if self._video_generator is None:
+            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config)
+
+        # Prepare inputs
+        with memory_efficient_inference():
+            if self.config["use_prompt_enhancer"]:
+                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.inputs = self.prepare_inputs()
+            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+
+        # Process audio
+        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
+        video_duration = self.config.get("video_duration", 5)
+        target_fps = self.config.get("target_fps", 16)
+        max_num_frames = self.config.get("target_video_length", 81)
+        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
+        expected_frames = min(max(1, int(float(video_duration) * target_fps)), audio_len)
+
+        # Segment audio
+        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
+
+        # Generate video segments
+        gen_video_list = []
+        cut_audio_list = []
+        prev_video = None
+
+        for idx, segment in enumerate(audio_segments):
+            # Update seed for each segment
+            self.config.seed = self.config.seed + idx
+            torch.manual_seed(self.config.seed)
+            logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")
+
+            # Process audio features
+            audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
+
+            # Generate video segment
+            with memory_efficient_inference():
+                gen_video = self._video_generator.generate_segment(
+                    self.inputs.copy(),  # Copy to avoid modifying original
+                    audio_features,
+                    prev_video=prev_video,
+                    prev_frame_length=5,
+                    segment_idx=idx,
+                )
+
+            # Extract relevant frames
+            start_frame = 0 if idx == 0 else 5
+            start_audio_frame = 0 if idx == 0 else int(6 * self._audio_processor.audio_sr / target_fps)
+            if segment.is_last and segment.useful_length:
+                end_frame = segment.end_frame - segment.start_frame
+                gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame:segment.useful_length])
+            elif segment.useful_length and expected_frames < max_num_frames:
+                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame:segment.useful_length])
+            else:
+                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame:])
+
+            # Update prev_video for next iteration
+            prev_video = gen_video
+
+            # Clean up GPU memory after each segment
+            del gen_video
+            torch.cuda.empty_cache()
+
+        # Merge results
+        with memory_efficient_inference():
+            gen_lvideo = torch.cat(gen_video_list, dim=2).float()
+            merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
+            comfyui_images = vae_to_comfyui_image(gen_lvideo)
+
+            # Apply frame interpolation if configured
+            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
+                interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
+                logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
+                comfyui_images = self.vfi_model.interpolate_frames(
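The prev-frame conditioning in both the old run() loop and the new generate_segment() relies on the same 4-to-1 temporal packing: 5 conditioning pixel frames collapse to 2 latent tokens, which in turn cover the first 5 pixel frames of the window, so the mask is 1 there and 0 everywhere else. A standalone recomputation of those numbers, not part of the commit (illustrative, assuming a 21-latent-frame segment, i.e. 81 pixel frames):

prev_frame_length = 5
nframe = 21  # latent frames of an 81-frame segment

prev_token_length = (prev_frame_length - 1) // 4 + 1       # 2 latent tokens carry the conditioning
frames_n = (nframe - 1) * 4 + 1                            # 81 pixel frames covered by the mask
prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)   # first 5 pixel frames stay unmasked
print(prev_token_length, frames_n, prev_frame_len)         # 2 81 5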
@@ -604,39 +524,118 @@ class WanAudioRunner(WanRunner):
...
@@ -604,39 +524,118 @@ class WanAudioRunner(WanRunner):
source_fps
=
target_fps
,
source_fps
=
target_fps
,
target_fps
=
interpolation_target_fps
,
target_fps
=
interpolation_target_fps
,
)
)
# Update target_fps for saving
target_fps
=
interpolation_target_fps
target_fps
=
interpolation_target_fps
# Convert audio to ComfyUI format
# Convert audio to ComfyUI format
# Convert numpy array to torch tensor and add batch dimension
audio_waveform
=
torch
.
from_numpy
(
merge_audio
).
unsqueeze
(
0
).
unsqueeze
(
0
)
audio_waveform
=
torch
.
from_numpy
(
merge_audio
).
unsqueeze
(
0
).
unsqueeze
(
0
)
# [batch, channels, samples]
comfyui_audio
=
{
"waveform"
:
audio_waveform
,
"sample_rate"
:
self
.
_audio_processor
.
audio_sr
}
comfyui_audio
=
{
"waveform"
:
audio_waveform
,
"sample_rate"
:
audio_sr
}
# Save video if requested
# Save video if requested
if
save_video
and
self
.
config
.
get
(
"save_video_path"
,
None
):
if
save_video
and
self
.
config
.
get
(
"save_video_path"
,
None
):
out_path
=
os
.
path
.
join
(
"./"
,
"video_merge.mp4"
)
self
.
_save_video_with_audio
(
comfyui_images
,
merge_audio
,
target_fps
)
audio_file
=
os
.
path
.
join
(
"./"
,
"audio_merge.wav"
)
# Use the updated target_fps (after interpolation if applied)
# Final cleanup
save_to_video
(
comfyui_images
,
out_path
,
target_fps
)
self
.
end_run
()
save_audio
(
merge_audio
,
audio_file
,
out_path
,
output_path
=
self
.
config
.
get
(
"save_video_path"
,
None
))
os
.
remove
(
out_path
)
os
.
remove
(
audio_file
)
return
comfyui_images
,
comfyui_audio
return
comfyui_images
,
comfyui_audio
def
run_pipeline
(
self
,
save_video
=
True
):
def
_save_video_with_audio
(
self
,
images
,
audio_array
,
fps
):
if
self
.
config
[
"use_prompt_enhancer"
]:
"""Save video with audio"""
self
.
config
[
"prompt_enhanced"
]
=
self
.
post_prompt_enhancer
()
import
tempfile
self
.
run_input_encoder_internal
()
with
tempfile
.
NamedTemporaryFile
(
suffix
=
".mp4"
,
delete
=
False
)
as
video_tmp
:
self
.
set_target_shape
()
video_path
=
video_tmp
.
name
self
.
init_scheduler
()
with
tempfile
.
NamedTemporaryFile
(
suffix
=
".wav"
,
delete
=
False
)
as
audio_tmp
:
self
.
model
.
scheduler
.
prepare
(
self
.
inputs
[
"image_encoder_output"
])
audio_path
=
audio_tmp
.
name
images
,
audio
=
self
.
run
(
save_video
)
# run() now returns both images and audio
self
.
end_run
()
gc
.
collect
()
try
:
torch
.
cuda
.
empty_cache
()
# Save video
save_to_video
(
images
,
video_path
,
fps
)
# Save audio
ta
.
save
(
audio_path
,
torch
.
tensor
(
audio_array
[
None
]),
sample_rate
=
self
.
_audio_processor
.
audio_sr
)
# Merge video and audio
output_path
=
self
.
config
.
get
(
"save_video_path"
)
parent_dir
=
os
.
path
.
dirname
(
output_path
)
if
parent_dir
and
not
os
.
path
.
exists
(
parent_dir
):
os
.
makedirs
(
parent_dir
,
exist_ok
=
True
)
subprocess
.
call
([
"/usr/bin/ffmpeg"
,
"-y"
,
"-i"
,
video_path
,
"-i"
,
audio_path
,
output_path
])
logger
.
info
(
f
"Saved video with audio to:
{
output_path
}
"
)
finally
:
# Clean up temp files
if
os
.
path
.
exists
(
video_path
):
os
.
remove
(
video_path
)
if
os
.
path
.
exists
(
audio_path
):
os
.
remove
(
audio_path
)
def
load_transformer
(
self
):
"""Load transformer with LoRA support"""
base_model
=
WanAudioModel
(
self
.
config
.
model_path
,
self
.
config
,
self
.
init_device
)
if
self
.
config
.
get
(
"lora_configs"
)
and
self
.
config
.
lora_configs
:
assert
not
self
.
config
.
get
(
"dit_quantized"
,
False
)
or
self
.
config
.
mm_config
.
get
(
"weight_auto_quant"
,
False
)
lora_wrapper
=
WanLoraWrapper
(
base_model
)
for
lora_config
in
self
.
config
.
lora_configs
:
lora_path
=
lora_config
[
"path"
]
strength
=
lora_config
.
get
(
"strength"
,
1.0
)
lora_name
=
lora_wrapper
.
load_lora
(
lora_path
)
lora_wrapper
.
apply_lora
(
lora_name
,
strength
)
logger
.
info
(
f
"Loaded LoRA:
{
lora_name
}
with strength:
{
strength
}
"
)
return
base_model
return
images
,
audio
def
load_image_encoder
(
self
):
"""Load image encoder"""
clip_model_dir
=
self
.
config
[
"model_path"
]
+
"/image_encoder"
image_encoder
=
WanVideoIPHandler
(
"CLIPModel"
,
repo_or_path
=
clip_model_dir
,
require_grad
=
False
,
mode
=
"eval"
,
device
=
self
.
init_device
,
dtype
=
torch
.
float16
)
return
image_encoder
def
run_image_encoder
(
self
,
config
,
vae_model
):
"""Run image encoder"""
ref_img
=
Image
.
open
(
config
.
image_path
)
ref_img
=
(
np
.
array
(
ref_img
).
astype
(
np
.
float32
)
-
127.5
)
/
127.5
ref_img
=
torch
.
from_numpy
(
ref_img
).
to
(
vae_model
.
device
)
ref_img
=
rearrange
(
ref_img
,
"H W C -> 1 C H W"
)
ref_img
=
ref_img
[:,
:
3
]
# Resize and crop image
cond_frms
,
tgt_h
,
tgt_w
=
adaptive_resize
(
ref_img
)
config
.
tgt_h
=
tgt_h
config
.
tgt_w
=
tgt_w
clip_encoder_out
=
self
.
image_encoder
.
encode
(
cond_frms
).
squeeze
(
0
).
to
(
torch
.
bfloat16
)
cond_frms
=
rearrange
(
cond_frms
,
"1 C H W -> 1 C 1 H W"
)
lat_h
,
lat_w
=
tgt_h
//
8
,
tgt_w
//
8
config
.
lat_h
=
lat_h
config
.
lat_w
=
lat_w
vae_encode_out
=
vae_model
.
encode
(
cond_frms
.
to
(
torch
.
float
),
config
)
if
isinstance
(
vae_encode_out
,
list
):
vae_encode_out
=
torch
.
stack
(
vae_encode_out
,
dim
=
0
).
to
(
torch
.
bfloat16
)
return
vae_encode_out
,
clip_encoder_out
def
set_target_shape
(
self
):
"""Set target shape for generation"""
ret
=
{}
num_channels_latents
=
16
if
self
.
config
.
task
==
"i2v"
:
self
.
config
.
target_shape
=
(
num_channels_latents
,
(
self
.
config
.
target_video_length
-
1
)
//
self
.
config
.
vae_stride
[
0
]
+
1
,
self
.
config
.
lat_h
,
self
.
config
.
lat_w
,
)
ret
[
"lat_h"
]
=
self
.
config
.
lat_h
ret
[
"lat_w"
]
=
self
.
config
.
lat_w
else
:
error_msg
=
"t2v task is not supported in WanAudioRunner"
assert
False
,
error_msg
ret
[
"target_shape"
]
=
self
.
config
.
target_shape
return
ret
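For reference, the target_shape tuple built in set_target_shape works out as follows; this is a sketch under the assumption that config.vae_stride[0] is 4 (the usual Wan 2.1 temporal stride), which is not stated in this diff:

# Latent shape for an 81-frame i2v segment, assuming a temporal VAE stride of 4.
target_video_length = 81
vae_stride_t = 4                  # assumed value of config.vae_stride[0]
num_channels_latents = 16
lat_h, lat_w = 60, 104            # e.g. a 480x832 target divided by 8 (illustrative)

latent_frames = (target_video_length - 1) // vae_stride_t + 1
print((num_channels_latents, latent_frames, lat_h, lat_w))  # (16, 21, 60, 104)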