xuwx1 / LightX2V / Commits / e58dd9fe

Commit e58dd9fe, authored Jun 25, 2025 by wangshankun
Audio-driven Wan video generation (audio驱动wan视频生成)
Parent: 7260cb2e
Showing 11 changed files with 1280 additions and 13 deletions (+1280, -13)
configs/wan_i2v_audio.json                                    +12   -0
lightx2v/infer.py                                             +5    -2
lightx2v/models/networks/wan/audio_adapter.py                 +414  -0
lightx2v/models/networks/wan/audio_model.py                   +78   -0
lightx2v/models/networks/wan/infer/post_wan_audio_infer.py    +43   -0
lightx2v/models/networks/wan/infer/pre_wan_audio_infer.py     +106  -0
lightx2v/models/networks/wan/infer/transformer_infer.py       +43   -11
lightx2v/models/networks/wan/infer/utils.py                   +40   -0
lightx2v/models/runners/wan/wan_audio_runner.py               +490  -0
lightx2v/models/schedulers/wan/scheduler.py                   +9    -0
scripts/run_wan_i2v_audio.sh                                  +40   -0
configs/wan_i2v_audio.json (new file, mode 100755)

{
    "infer_steps": 20,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false
}
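As a point of reference, here is a minimal sketch of how this JSON could be read and merged with a few runtime overrides before being handed to the runner; the exact merge logic inside lightx2v.infer is not reproduced here, and target_fps / video_duration are optional keys that wan_audio_runner.py reads via config.get further below.

import json

# Hypothetical helper, not part of this commit: load the audio config and
# apply overrides, mirroring how extra keys are read with config.get in the runner.
def load_config(config_json_path: str, **overrides) -> dict:
    with open(config_json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    config.update(overrides)
    return config

if __name__ == "__main__":
    cfg = load_config("configs/wan_i2v_audio.json", target_fps=16, video_duration=8)
    print(cfg["infer_steps"], cfg["target_video_length"])  # 20 81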
lightx2v/infer.py

@@ -14,6 +14,7 @@ from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
+from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner

@@ -41,14 +42,16 @@ def init_runner(config):
 async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--use_prompt_enhancer", action="store_true")
-    parser.add_argument("--prompt", type=str, required=True)
+    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
     parser.add_argument("--negative_prompt", type=str, default="")
+    parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
+    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
     parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
     args = parser.parse_args()
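The new "wan2.1_audio" choice is dispatched to WanAudioRunner through the RUNNER_REGISTER decorator applied in wan_audio_runner.py below. As a rough, hedged sketch of that kind of registry lookup (register_runner and _RUNNERS are illustrative names, not lightx2v's actual implementation):

# Illustrative registry pattern only; lightx2v's RUNNER_REGISTER may differ in detail.
_RUNNERS = {}

def register_runner(name):
    def decorator(cls):
        _RUNNERS[name] = cls
        return cls
    return decorator

@register_runner("wan2.1_audio")
class DummyAudioRunner:
    def __init__(self, config):
        self.config = config

def init_runner(config):
    # init_runner in lightx2v.infer presumably resolves model_cls in a similar way
    return _RUNNERS[config["model_cls"]](config)

runner = init_runner({"model_cls": "wan2.1_audio"})
print(type(runner).__name__)  # DummyAudioRunner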
lightx2v/models/networks/wan/audio_adapter.py (new file, mode 100644)

import flash_attn
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
from loguru import logger
import pdb
import os
import safetensors
from typing import List, Optional, Tuple, Union


def load_safetensors(in_path: str):
    if os.path.isdir(in_path):
        return load_safetensors_from_dir(in_path)
    elif os.path.isfile(in_path):
        return load_safetensors_from_path(in_path)
    else:
        raise ValueError(f"{in_path} does not exist")


def load_safetensors_from_path(in_path: str):
    tensors = {}
    with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors


def load_safetensors_from_dir(in_dir: str):
    tensors = {}
    safetensors = os.listdir(in_dir)
    safetensors = [f for f in safetensors if f.endswith(".safetensors")]
    for f in safetensors:
        tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
    return tensors


def load_pt_safetensors(in_path: str):
    ext = os.path.splitext(in_path)[-1]
    if ext in (".pt", ".pth", ".tar"):
        state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
    else:
        state_dict = load_safetensors(in_path)
    return state_dict


def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
    import torch.distributed as dist

    if (dist.is_initialized() and dist.get_rank() == 0) or (not dist.is_initialized()):
        state_dict = load_pt_safetensors(in_path)
        model.load_state_dict(state_dict, strict=strict)
    if dist.is_initialized():
        dist.barrier()
    return model.to(dtype=torch.bfloat16, device="cuda")


def linear_interpolation(features, output_len: int):
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=output_len, align_corners=False, mode="linear")
    return output_features.transpose(1, 2)


def get_q_lens_audio_range(
    batchsize,
    n_tokens_per_rank,
    n_query_tokens,
    n_tokens_per_frame,
    sp_rank,
):
    if n_query_tokens == 0:
        q_lens = [1] * batchsize
        return q_lens, 0, 1
    idx0 = n_tokens_per_rank * sp_rank
    first_length = idx0 - idx0 // n_tokens_per_frame * n_tokens_per_frame
    n_frames = (n_query_tokens - first_length) // n_tokens_per_frame
    last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length
    q_lens = []
    if first_length > 0:
        q_lens.append(first_length)
    q_lens += [n_tokens_per_frame] * n_frames
    if last_length > 0:
        q_lens.append(last_length)
    t0 = idx0 // n_tokens_per_frame
    idx1 = idx0 + n_query_tokens
    t1 = math.ceil(idx1 / n_tokens_per_frame)
    return q_lens * batchsize, t0, t1


class PerceiverAttentionCA(nn.Module):
    def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
        super().__init__()
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads
        kv_dim = inner_dim if kv_dim is None else kv_dim
        self.norm_kv = nn.LayerNorm(kv_dim)
        self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)
        self.to_q = nn.Linear(inner_dim, inner_dim)
        self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
        self.to_out = nn.Linear(inner_dim, inner_dim)
        if adaLN:
            self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
        else:
            shift_scale_gate = torch.zeros((1, 3, inner_dim))
            shift_scale_gate[:, 2] = 1
            self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

    def forward(self, x, latents, t_emb, q_lens, k_lens):
        """x shape (batchsize, latent_frame, audio_tokens_per_latent, model_dim), latents (batchsize, length, model_dim)"""
        batchsize = len(x)
        x = self.norm_kv(x)
        shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
        latents = self.norm_q(latents) * (1 + scale) + shift
        q = self.to_q(latents)
        k, v = self.to_kv(x).chunk(2, dim=-1)
        q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
        k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
        v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads)
        out = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=q_lens.max(),
            max_seqlen_k=k_lens.max(),
            dropout_p=0.0,
            softmax_scale=None,
            causal=False,
            window_size=(-1, -1),
            deterministic=False,
        )
        out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize)
        return self.to_out(out) * gate


class AudioProjection(nn.Module):
    def __init__(
        self,
        audio_feature_dim: int = 768,
        n_neighbors: tuple = (2, 2),
        num_tokens: int = 32,
        mlp_dims: tuple = (1024, 1024, 32 * 768),
        transformer_layers: int = 4,
    ):
        super().__init__()
        mlp = []
        self.left, self.right = n_neighbors
        self.audio_frames = sum(n_neighbors) + 1
        in_dim = audio_feature_dim * self.audio_frames
        for i, out_dim in enumerate(mlp_dims):
            mlp.append(nn.Linear(in_dim, out_dim))
            if i != len(mlp_dims) - 1:
                mlp.append(nn.ReLU())
            in_dim = out_dim
        self.mlp = nn.Sequential(*mlp)
        self.norm = nn.LayerNorm(mlp_dims[-1] // num_tokens)
        self.num_tokens = num_tokens
        if transformer_layers > 0:
            decoder_layer = nn.TransformerDecoderLayer(d_model=audio_feature_dim, nhead=audio_feature_dim // 64, dim_feedforward=4 * audio_feature_dim, dropout=0.0, batch_first=True)
            self.transformer_decoder = nn.TransformerDecoder(
                decoder_layer,
                num_layers=transformer_layers,
            )
        else:
            self.transformer_decoder = None

    def forward(self, audio_feature, latent_frame):
        video_frame = (latent_frame - 1) * 4 + 1
        audio_feature_ori = audio_feature
        audio_feature = linear_interpolation(audio_feature_ori, video_frame)
        if self.transformer_decoder is not None:
            audio_feature = self.transformer_decoder(audio_feature, audio_feature_ori)
        audio_feature = F.pad(audio_feature, pad=(0, 0, self.left, self.right), mode="replicate")
        audio_feature = audio_feature.unfold(dimension=1, size=self.audio_frames, step=1)
        audio_feature = rearrange(audio_feature, "B T C W -> B T (W C)")
        audio_feature = self.mlp(audio_feature)  # (B, video_frame, C)
        audio_feature = rearrange(audio_feature, "B T (N C) -> B T N C", N=self.num_tokens)  # (B, video_frame, num_tokens, C)
        return self.norm(audio_feature)


class TimeEmbedding(nn.Module):
    def __init__(self, dim, time_freq_dim, time_proj_dim):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)

    def forward(
        self,
        timestep: torch.Tensor,
    ):
        timestep = self.timesteps_proj(timestep)
        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep)
        timestep_proj = self.time_proj(self.act_fn(temb))
        return timestep_proj


class AudioAdapter(nn.Module):
    def __init__(
        self,
        attention_head_dim=64,
        num_attention_heads=40,
        base_num_layers=30,
        interval=1,
        audio_feature_dim: int = 768,
        num_tokens: int = 32,
        mlp_dims: tuple = (1024, 1024, 32 * 768),
        time_freq_dim: int = 256,
        projection_transformer_layers: int = 4,
    ):
        super().__init__()
        self.audio_proj = AudioProjection(
            audio_feature_dim=audio_feature_dim,
            n_neighbors=(2, 2),
            num_tokens=num_tokens,
            mlp_dims=mlp_dims,
            transformer_layers=projection_transformer_layers,
        )
        # self.num_tokens = num_tokens * 4
        self.num_tokens_x4 = num_tokens * 4
        self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
        ca_num = math.ceil(base_num_layers / interval)
        self.base_num_layers = base_num_layers
        self.interval = interval
        self.ca = nn.ModuleList(
            [
                PerceiverAttentionCA(
                    dim_head=attention_head_dim,
                    heads=num_attention_heads,
                    kv_dim=mlp_dims[-1] // num_tokens,
                    adaLN=time_freq_dim > 0,
                )
                for _ in range(ca_num)
            ]
        )
        self.dim = attention_head_dim * num_attention_heads
        if time_freq_dim > 0:
            self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3)
        else:
            self.time_embedding = None

    def rearange_audio_features(self, audio_feature: torch.Tensor):
        # audio_feature (B, video_frame, num_tokens, C)
        audio_feature_0 = audio_feature[:, :1]
        audio_feature_0 = torch.repeat_interleave(audio_feature_0, repeats=4, dim=1)
        audio_feature = torch.cat([audio_feature_0, audio_feature[:, 1:]], dim=1)  # (B, 4 * latent_frame, num_tokens, C)
        audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
        return audio_feature

    def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0):
        def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight):
            """t, h, w specify the latent_frame, latent_height, latent_width after hidden_states is patchified.
            latent_frame does not include the reference images, so the audio features and hidden_states are strictly aligned.
            """
            if len(hidden_states.shape) == 2:
                # expand the batch dimension (bs = 1)
                hidden_states = hidden_states.unsqueeze(0)
            # print(weight)
            t, h, w = grid_sizes[0].tolist()
            n_tokens = t * h * w
            ori_dtype = hidden_states.dtype
            device = hidden_states.device
            bs, n_tokens_per_rank = hidden_states.shape[:2]
            tail_length = n_tokens_per_rank - n_tokens
            n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
            if n_query_tokens > 0:
                hidden_states_aligned = hidden_states[:, :n_query_tokens]
                hidden_states_tail = hidden_states[:, n_query_tokens:]
            else:
                # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
                hidden_states_aligned = hidden_states[:, :1]
                hidden_states_tail = hidden_states[:, 1:]
            q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0)
            q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
            """
            processing audio features in sp_state can be moved outside.
            """
            x = x[:, t0:t1]
            x = x.to(dtype)
            k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
            assert q_lens.shape == k_lens.shape
            # ca_block: the cross-attention module
            residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
            residual = residual.to(ori_dtype)
            # the audio cross-attention output is injected back as a residual
            if n_query_tokens == 0:
                residual = residual * 0.0
            hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
            if len(hidden_states.shape) == 3:
                # squeeze the batch dimension back out (bs = 1)
                hidden_states = hidden_states.squeeze(0)
            return hidden_states

        x = self.audio_proj(audio_feat, latent_frame)
        x = self.rearange_audio_features(x)
        x = x + self.audio_pe
        if self.time_embedding is not None:
            t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
        else:
            t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
        ret_dict = {}
        for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
            block_dict = {
                "kwargs": {
                    "ca_block": self.ca[block_idx],
                    "x": x,
                    "weight": weight,
                    "t_emb": t_emb,
                    "dtype": x.dtype,
                },
                "modify_func": modify_hidden_states,
            }
            ret_dict[base_idx] = block_dict
        return ret_dict

    @classmethod
    def from_transformer(
        cls,
        transformer,
        audio_feature_dim: int = 1024,
        interval: int = 1,
        time_freq_dim: int = 256,
        projection_transformer_layers: int = 4,
    ):
        num_attention_heads = transformer.config["num_heads"]
        base_num_layers = transformer.config["num_layers"]
        attention_head_dim = transformer.config["dim"] // num_attention_heads
        audio_adapter = AudioAdapter(
            attention_head_dim,
            num_attention_heads,
            base_num_layers,
            interval=interval,
            audio_feature_dim=audio_feature_dim,
            time_freq_dim=time_freq_dim,
            projection_transformer_layers=projection_transformer_layers,
            mlp_dims=(1024, 1024, 32 * audio_feature_dim),
        )
        return audio_adapter

    def get_fsdp_wrap_module_list(self):
        ret_list = list(self.ca)
        return ret_list

    def enable_gradient_checkpointing(self):
        pass


class AudioAdapterPipe:
    def __init__(
        self,
        audio_adapter: AudioAdapter,
        audio_encoder_repo: str = "microsoft/wavlm-base-plus",
        dtype=torch.float32,
        device="cuda",
        generator=None,
        tgt_fps: int = 15,
        weight: float = 1.0,
    ) -> None:
        self.audio_adapter = audio_adapter
        self.dtype = dtype
        self.device = device
        self.generator = generator
        self.audio_encoder_dtype = torch.float16
        self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
        self.audio_encoder.eval()
        self.audio_encoder.to(device, self.audio_encoder_dtype)
        self.tgt_fps = tgt_fps
        self.weight = weight
        if "base" in audio_encoder_repo:
            self.audio_feature_dim = 768
        else:
            self.audio_feature_dim = 1024

    def update_model(self, audio_adapter):
        self.audio_adapter = audio_adapter

    def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
        # audio_input_feat is from AudioPreprocessor
        latent_frame = latent_shape[2]
        if len(audio_input_feat.shape) == 1:
            # expand to batchsize = 1
            audio_input_feat = audio_input_feat.unsqueeze(0)
            latent_frame = latent_shape[1]
        video_frame = (latent_frame - 1) * 4 + 1
        audio_length = int(50 / self.tgt_fps * video_frame)
        with torch.no_grad():
            audio_input_feat = audio_input_feat.to(self.device, self.audio_encoder_dtype)
            try:
                audio_feat = self.audio_encoder(audio_input_feat, return_dict=True).last_hidden_state
            except Exception as err:
                audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device)
                print(err)
        audio_feat = audio_feat.to(self.dtype)
        if dropout_cond is not None:
            audio_feat = dropout_cond(audio_feat)
        return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)
lightx2v/models/networks/wan/audio_model.py (new file, mode 100644)

import os
import torch
import time
import glob
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.post_wan_audio_infer import WanAudioPostInfer
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferTeaCaching,
)


class WanAudioModel(WanModel):
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferTeaCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

        self.scheduler.noise_pred = noise_pred_cond

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0

            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()
lightx2v/models/networks/wan/infer/post_wan_audio_infer.py (new file, mode 100755)

import math
import torch
import torch.cuda.amp as amp
from loguru import logger
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer


class WanAudioPostInfer(WanPostInfer):
    def __init__(self, config):
        self.out_dim = config["out_dim"]
        self.patch_size = (1, 2, 2)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, x, e, grid_sizes, valid_patch_length):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:
            # for diffusion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
            e = [ei.squeeze(1) for ei in e]

        norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
        out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
        x = weights.head.apply(out)
        x = x[:, :valid_patch_length]
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        x = x.unsqueeze(0)
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[: math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum("fhwpqrc->cfphqwr", u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out
lightx2v/models/networks/wan/infer/pre_wan_audio_infer.py (new file, mode 100755)

import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from loguru import logger


class WanAudioPreInfer(WanPreInfer):
    def __init__(self, config):
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        d = config["dim"] // config["num_heads"]
        self.task = config["task"]
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        ).cuda()
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]
        self.text_len = config["text_len"]

    def infer(self, weights, inputs, positive):
        ltnt_channel = self.scheduler.latents.size(0)
        prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
        prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
        hidden_states = self.scheduler.latents.unsqueeze(0)
        hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
        hidden_states = hidden_states.squeeze(0)
        x = [hidden_states]
        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])

        audio_dit_blocks = []
        audio_encoder_output = inputs["audio_encoder_output"]
        audio_model_input = {
            "audio_input_feat": audio_encoder_output.to(hidden_states.device),
            "latent_shape": hidden_states.shape,
            "timestep": t,
        }
        audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))

        if positive:
            context = inputs["text_encoder_output"]["context"]
        else:
            context = inputs["text_encoder_output"]["context_null"]
        seq_len = self.scheduler.seq_len
        clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
        ref_image_encoder = inputs["image_encoder_output"]["vae_encode_out"]

        batch_size = len(x)
        num_channels, num_frames, height, width = x[0].shape
        _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
        if ref_num_channels != num_channels:
            zero_padding = torch.zeros(
                (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
                dtype=self.scheduler.latents.dtype,
                device=self.scheduler.latents.device,
            )
            ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
        y = list(torch.unbind(ref_image_encoder, dim=0))  # split the first (batch) dimension into a list

        # embeddings
        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
        x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
        assert seq_lens.max() <= seq_len
        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
        valid_patch_length = x[0].size(0)

        y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
        y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
        y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]

        x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
        x = torch.stack(x, dim=0)

        embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
        embed = weights.time_embedding_0.apply(embed)
        embed = torch.nn.functional.silu(embed)
        embed = weights.time_embedding_2.apply(embed)
        embed0 = torch.nn.functional.silu(embed)
        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

        # text embeddings
        stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
        out = weights.text_embedding_0.apply(stacked.squeeze(0))
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = weights.text_embedding_2.apply(out)

        if self.task == "i2v":
            context_clip = weights.proj_0.apply(clip_fea)
            context_clip = weights.proj_1.apply(context_clip)
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
            context_clip = weights.proj_3.apply(context_clip)
            context_clip = weights.proj_4.apply(context_clip)
            context = torch.concat([context_clip, context], dim=0)

        return (
            embed,
            x_grid_sizes,
            (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks),
            valid_patch_length,
        )
lightx2v/models/networks/wan/infer/transformer_infer.py

 import torch
-from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
+from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
 from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
 )
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
+from loguru import logger
+import pdb
+import os


 class WanTransformerInfer(BaseTransformerInfer):

@@ -64,10 +67,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         return cu_seqlens_q, cu_seqlens_k

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
+        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)

-    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
                 self.weights_stream_mgr.active_weights[0] = weights.blocks[0]

@@ -92,7 +95,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(weights.blocks_num):
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:

@@ -133,7 +136,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights)
         for block_idx in range(weights.blocks_num):

@@ -194,7 +197,22 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None):
+        if rotary_emb is None:
+            return None
+        self.use_real = False
+        rope_t_dim = 44
+        if self.use_real:
+            freqs_cos, freqs_sin = rotary_emb
+            freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
+            freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
+            return freqs_cos, freqs_sin
+        else:
+            freqs_cis = rotary_emb
+            freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
+            return freqs_cis
+
+    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
                 weights.blocks[block_idx],

@@ -206,6 +224,12 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
+                for ipa_out in audio_dit_blocks:
+                    if block_idx in ipa_out:
+                        cur_modify = ipa_out[block_idx]
+                        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):

@@ -265,14 +289,23 @@ class WanTransformerInfer(BaseTransformerInfer):
         v = weights.self_attn_v.apply(norm1_out).view(s, n, d)

         if not self.parallel_attention:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+            if self.config.get("audio_sr", False):
+                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
+            else:
+                freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         else:
-            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            if self.config.get("audio_sr", False):
+                freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            else:
+                freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+        freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)

         q = self.apply_rotary_emb_func(q, freqs_i)
         k = self.apply_rotary_emb_func(k, freqs_i)
-        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
+        k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
+        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)

         if self.clean_cuda_cache:
             del freqs_i, norm1_out, norm1_weight, norm1_bias

@@ -353,7 +386,6 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q,
                 k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
             )
             img_attn_out = weights.cross_attn_2.apply(
                 q=q,
                 k=k_img,
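zero_temporal_component_in_3DRoPE blanks the temporal slice of the rotary embedding for every token past valid_token_length, which in this commit is the padded/reference-image tail appended after the video tokens, so those tokens carry no temporal position. A standalone sketch of the complex (use_real=False) path actually exercised here, with illustrative sizes:

import torch

rope_t_dim = 44
valid_token_length = 240
# (tokens, 1, head_dim // 2), illustrative sizes
freqs_cis = torch.randn(300, 1, 64, dtype=torch.complex64)
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0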
lightx2v/models/networks/wan/infer/utils.py

 import torch
 import torch.distributed as dist
+from loguru import logger
 from lightx2v.utils.envs import *

@@ -19,6 +20,45 @@ def compute_freqs(c, grid_sizes, freqs):
     return freqs_i


+def compute_freqs_audio(c, grid_sizes, freqs):
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+    f, h, w = grid_sizes[0].tolist()
+    f = f + 1
+    seq_len = f * h * w
+    freqs_i = torch.cat(
+        [
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+        ],
+        dim=-1,
+    ).reshape(seq_len, 1, -1)
+    return freqs_i
+
+
+def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
+    world_size = dist.get_world_size()
+    cur_rank = dist.get_rank()
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+    f, h, w = grid_sizes[0].tolist()
+    f = f + 1
+    seq_len = f * h * w
+    freqs_i = torch.cat(
+        [
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+        ],
+        dim=-1,
+    ).reshape(seq_len, 1, -1)
+    freqs_i = pad_freqs(freqs_i, s * world_size)
+    s_per_rank = s
+    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
+    return freqs_i_rank
+
+
 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0].tolist()
lightx2v/models/runners/wan/wan_audio_runner.py (new file, mode 100644)

import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
from loguru import logger
import torch.distributed as dist
from einops import rearrange
import torchaudio as ta
from transformers import AutoFeatureExtractor
from torchvision.datasets.folder import IMG_EXTENSIONS
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
import subprocess
import warnings
from typing import Optional, Tuple, Union
import pdb


def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
    tgt_ar = tgt_h / tgt_w
    ori_ar = ori_h / ori_w
    if abs(ori_ar - tgt_ar) < 0.01:
        return 0, ori_h, 0, ori_w
    if ori_ar > tgt_ar:
        crop_h = int(tgt_ar * ori_w)
        y0 = (ori_h - crop_h) // 2
        y1 = y0 + crop_h
        return y0, y1, 0, ori_w
    else:
        crop_w = int(ori_h / tgt_ar)
        x0 = (ori_w - crop_w) // 2
        x1 = x0 + crop_w
        return 0, ori_h, x0, x1


def isotropic_crop_resize(frames: torch.Tensor, size: tuple):
    """
    frames: (T, C, H, W)
    size: (H, W)
    """
    ori_h, ori_w = frames.shape[2:]
    h, w = size
    y0, y1, x0, x1 = get_crop_bbox(ori_h, ori_w, h, w)
    cropped_frames = frames[:, :, y0:y1, x0:x1]
    resized_frames = resize(cropped_frames, size, InterpolationMode.BICUBIC, antialias=True)
    return resized_frames


def adaptive_resize(img):
    bucket_config = {
        0.667: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64), np.array([0.2, 0.5, 0.3])),
        1.0: (np.array([[480, 480], [576, 576], [704, 704], [960, 960]], dtype=np.int64), np.array([0.1, 0.1, 0.5, 0.3])),
        1.5: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64)[:, ::-1], np.array([0.2, 0.5, 0.3])),
    }
    ori_height = img.shape[-2]
    ori_weight = img.shape[-1]
    ori_ratio = ori_height / ori_weight
    aspect_ratios = np.array(list(bucket_config.keys()))
    closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
    closet_ratio = aspect_ratios[closet_aspect_idx]
    target_h, target_w = 480, 832
    for resolution in bucket_config[closet_ratio][0]:
        if ori_height * ori_weight >= resolution[0] * resolution[1]:
            target_h, target_w = resolution
    cropped_img = isotropic_crop_resize(img, (target_h, target_w))
    return cropped_img, target_h, target_w


def array_to_video(
    image_array: np.ndarray,
    output_path: str,
    fps: Union[int, float] = 30,
    resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
    disable_log: bool = False,
    lossless: bool = True,
) -> None:
    if not isinstance(image_array, np.ndarray):
        raise TypeError("Input should be np.ndarray.")
    assert image_array.ndim == 4
    assert image_array.shape[-1] == 3
    if resolution:
        height, width = resolution
        width += width % 2
        height += height % 2
    else:
        image_array = pad_for_libx264(image_array)
        height, width = image_array.shape[1], image_array.shape[2]
    if lossless:
        command = [
            "/usr/bin/ffmpeg",
            "-y",                                 # overwrite output file if it exists
            "-f", "rawvideo",
            "-s", f"{int(width)}x{int(height)}",  # size of one frame
            "-pix_fmt", "bgr24",
            "-r", f"{fps}",                       # frames per second
            "-loglevel", "error",
            "-threads", "4",
            "-i", "-",                            # the input comes from a pipe
            "-vcodec", "libx264rgb",
            "-crf", "0",
            "-an",                                # tells ffmpeg not to expect any audio
            output_path,
        ]
    else:
        command = [
            "/usr/bin/ffmpeg",
            "-y",                                 # overwrite output file if it exists
            "-f", "rawvideo",
            "-s", f"{int(width)}x{int(height)}",  # size of one frame
            "-pix_fmt", "bgr24",
            "-r", f"{fps}",                       # frames per second
            "-loglevel", "error",
            "-threads", "4",
            "-i", "-",                            # the input comes from a pipe
            "-vcodec", "libx264",
            "-an",                                # tells ffmpeg not to expect any audio
            output_path,
        ]
    if not disable_log:
        print(f'Running "{" ".join(command)}"')
    process = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.stdin is None or process.stderr is None:
        raise BrokenPipeError("No buffer received.")
    index = 0
    while True:
        if index >= image_array.shape[0]:
            break
        process.stdin.write(image_array[index].tobytes())
        index += 1
    process.stdin.close()
    process.stderr.close()
    process.wait()


def pad_for_libx264(image_array):
    if image_array.ndim == 2 or (image_array.ndim == 3 and image_array.shape[2] == 3):
        hei_index = 0
        wid_index = 1
    elif image_array.ndim == 4 or (image_array.ndim == 3 and image_array.shape[2] != 3):
        hei_index = 1
        wid_index = 2
    else:
        return image_array
    hei_pad = image_array.shape[hei_index] % 2
    wid_pad = image_array.shape[wid_index] % 2
    if hei_pad + wid_pad > 0:
        pad_width = []
        for dim_index in range(image_array.ndim):
            if dim_index == hei_index:
                pad_width.append((0, hei_pad))
            elif dim_index == wid_index:
                pad_width.append((0, wid_pad))
            else:
                pad_width.append((0, 0))
        values = 0
        image_array = np.pad(image_array, pad_width, mode="constant", constant_values=values)
    return image_array


def generate_unique_path(path):
    if not os.path.exists(path):
        return path
    root, ext = os.path.splitext(path)
    index = 1
    new_path = f"{root}-{index}{ext}"
    while os.path.exists(new_path):
        index += 1
        new_path = f"{root}-{index}{ext}"
    return new_path


def save_to_video(gen_lvideo, out_path, target_fps):
    print(gen_lvideo.shape)
    gen_lvideo = rearrange(gen_lvideo, "B C T H W -> B T H W C")
    gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
    gen_lvideo = gen_lvideo[..., ::-1].copy()
    generate_unique_path(out_path)
    array_to_video(gen_lvideo, output_path=out_path, fps=target_fps, lossless=False)


def save_audio(
    audio_array: np.ndarray,
    audio_name: str,
    video_name: str = None,
    sr: int = 16000,
):
    logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
    if not os.path.exists(audio_name):
        ta.save(
            audio_name,
            torch.tensor(audio_array[None]),
            sample_rate=sr,
        )
    out_video = f"{video_name[:-4]}_with_audio.mp4"
    # generate_unique_path(out_path)
    cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
    subprocess.call(cmd, shell=True)


@RUNNER_REGISTER("wan2.1_audio")
class WanAudioRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)

    def load_audio_models(self):
        self.audio_encoder = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
        audio_adapter = AudioAdapter.from_transformer(
            self.model,
            audio_feature_dim=1024,
            interval=1,
            time_freq_dim=256,
            projection_transformer_layers=4,
        )
        load_path = "/mnt/aigc/zoemodels/Zoetrained/vigendit/audio_driven/audio_adapter/audio_adapter_V1_0507_bf16.safetensors"
        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, load_path, strict=False)
        device = self.model.device
        audio_encoder_repo = "/mnt/aigc/zoemodels/models--TencentGameMate--chinese-hubert-large/snapshots/90cb660492214f687e60f5ca509b20edae6e75bd"
        audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
        return audio_adapter_pipe

    def load_transformer(self):
        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
        return base_model

    def run_image_encoder(self, config, vae_model):
        ref_img = Image.open(config.image_path)
        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
        ref_img = torch.from_numpy(ref_img).to(vae_model.device)
        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
        ref_img = ref_img[:, :3]
        # resize and crop image
        cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
        config.tgt_h = tgt_h
        config.tgt_w = tgt_w
        clip_encoder_out = self.image_encoder.visual([cond_frms.squeeze(0)[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
        lat_h, lat_w = tgt_h // 8, tgt_w // 8
        config.lat_h = lat_h
        config.lat_w = lat_w
        vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
        if isinstance(vae_encode_out, list):
            # convert the list to a tensor
            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
        return vae_encode_out, clip_encoder_out

    def run_input_encoder_internal(self):
        image_encoder_output = None
        if os.path.isfile(self.config.image_path):
            with ProfilingContext("Run Img Encoder"):
                vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
                image_encoder_output = {
                    "clip_encoder_out": clip_encoder_out,
                    "vae_encode_out": vae_encode_out,
                }
                logger.info(f"clip_encoder_out: {clip_encoder_out.shape} vae_encode_out: {vae_encode_out.shape}")
        with ProfilingContext("Run Text Encoder"):
            with open(self.config["prompt_path"], "r", encoding="utf-8") as f:
                prompt = f.readline().strip()
            logger.info(f"Prompt: {prompt}")
            img = Image.open(self.config["image_path"]).convert("RGB")
            text_encoder_output = self.run_text_encoder(prompt, img)
        self.set_target_shape()
        self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
        gc.collect()
        torch.cuda.empty_cache()

    def set_target_shape(self):
        ret = {}
        num_channels_latents = 16
        if self.config.task == "i2v":
            self.config.target_shape = (
                num_channels_latents,
                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
                self.config.lat_h,
                self.config.lat_w,
            )
            ret["lat_h"] = self.config.lat_h
            ret["lat_w"] = self.config.lat_w
        else:
            error_msg = "t2v task is not supported in WanAudioRunner"
            assert 1 == 0, error_msg
        ret["target_shape"] = self.config.target_shape
        return ret

    def run(self):
        def load_audio(in_path: str, sr: float = 16000):
            audio_array, ori_sr = ta.load(in_path)
            audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
            return audio_array.numpy()

        def get_audio_range(start_frame: int, end_frame: int, fps: float, audio_sr: float = 16000):
            audio_frame_rate = audio_sr / fps
            return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)

        self.inputs["audio_adapter_pipe"] = self.load_audio_models()

        # process audio
        audio_sr = 16000
        max_num_frames = 81  # one Wan2.1 segment is at most 81 frames (5 s at 16 fps)
        target_fps = self.config.get("target_fps", 16)  # frame rate used for audio/video sync
        video_duration = self.config.get("video_duration", 8)  # desired output video duration
        audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
        audio_len = int(audio_array.shape[0] / audio_sr * target_fps)

        prev_frame_length = 5
        prev_token_length = (prev_frame_length - 1) // 4 + 1
        max_num_audio_length = int((max_num_frames + 1) / target_fps * 16000)
        interval_num = 1

        # expected_frames
        expected_frames = min(max(1, int(float(video_duration) * target_fps)), audio_len)

        res_frame_num = 0
        if expected_frames <= max_num_frames:
            interval_num = 1
        else:
            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
            if res_frame_num > 5:
                interval_num += 1

        audio_start, audio_end = get_audio_range(0, expected_frames, fps=target_fps, audio_sr=audio_sr)
        audio_array_ori = audio_array[audio_start:audio_end]

        gen_video_list = []
        cut_audio_list = []

        # reference latents
        tgt_h = self.config.tgt_h
        tgt_w = self.config.tgt_w
        device = self.model.scheduler.latents.device
        dtype = torch.bfloat16
        vae_dtype = torch.float

        for idx in range(interval_num):
            torch.manual_seed(42 + idx)
            logger.info(f"### manual_seed: {42 + idx} ####")
            useful_length = -1
            if idx == 0:
                # first segment: condition on zero padding
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = 0
                audio_start, audio_end = get_audio_range(0, max_num_frames, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                if expected_frames < max_num_frames:
                    useful_length = audio_array.shape[0]
                    audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
                audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            elif res_frame_num > 5 and idx == interval_num - 1:
                # the last segment may have fewer than 81 frames
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = prev_token_length
                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                useful_length = audio_array.shape[0]
                audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
                audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
            else:
                # middle segments: full 81 frames conditioned on prev_latents
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = prev_token_length
                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
                audio_array = audio_array_ori[audio_start:audio_end]
                audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)

            self.inputs["audio_encoder_output"] = audio_input_feat.to(device)

            if idx != 0:
                self.model.scheduler.reset()

            if prev_latents is not None:
                ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
                bs = 1
                prev_mask = torch.zeros((bs, 1, nframe, height, width), device=device, dtype=dtype)
                if prev_len > 0:
                    prev_mask[:, :, :prev_len] = 1.0
                previmg_encoder_output = {
                    "prev_latents": prev_latents,
                    "prev_mask": prev_mask,
                }
                self.inputs["previmg_encoder_output"] = previmg_encoder_output

            for step_index in range(self.model.scheduler.infer_steps):
                logger.info(f"==> step_index: {step_index} / {self.model.scheduler.infer_steps}")
                with ProfilingContext4Debug("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)
                with ProfilingContext4Debug("infer"):
                    self.model.infer(self.inputs)
                with ProfilingContext4Debug("step_post"):
                    self.model.scheduler.step_post()

            latents = self.model.scheduler.latents
            generator = self.model.scheduler.generator
            gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
            # gen_img = vae_handler.decode(xt.to(vae_dtype))
            # B, C, T, H, W
            gen_video = torch.clamp(gen_video, -1, 1)
            start_frame = 0 if idx == 0 else prev_frame_length
            start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
            print(f"---- {idx}, {gen_video[:, :, start_frame:].shape}")
            if res_frame_num > 5 and idx == interval_num - 1:
                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
            elif expected_frames < max_num_frames and useful_length != -1:
                gen_video_list.append(gen_video[:, :, start_frame:expected_frames])
                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
            else:
                gen_video_list.append(gen_video[:, :, start_frame:])
                cut_audio_list.append(audio_array[start_audio_frame:])

        gen_lvideo = torch.cat(gen_video_list, dim=2).float()
        merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
        out_path = os.path.join("./", "video_merge.mp4")
        audio_file = os.path.join("./", "audio_merge.wav")
        save_to_video(gen_lvideo, out_path, target_fps)
        save_audio(merge_audio, audio_file, out_path)
        os.remove(out_path)
        os.remove(audio_file)

    async def run_pipeline(self):
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
        self.run_input_encoder_internal()
        self.set_target_shape()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        self.run()
        self.end_run()
        torch.cuda.empty_cache()
        gc.collect()
lightx2v/models/schedulers/wan/scheduler.py

@@ -115,6 +115,15 @@ class WanScheduler(BaseScheduler):
             x0_pred = sample - sigma_t * model_output
         return x0_pred

+    def reset(self):
+        self.model_outputs = [None] * self.solver_order
+        self.timestep_list = [None] * self.solver_order
+        self.last_sample = None
+        self.noise_pred = None
+        self.this_order = None
+        self.lower_order_nums = 0
+        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+
     def multistep_uni_p_bh_update(
         self,
         model_output: torch.Tensor,
scripts/run_wan_i2v_audio.sh (new file, mode 100755)

#!/bin/bash

# set paths first
lightx2v_path="/mnt/Text2Video/wangshankun/lightx2v"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.1-I2V-Audio-14B-720P/"

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
    --model_cls wan2.1_audio \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan_i2v_audio.json \
    --prompt_path ${lightx2v_path}/assets/inputs/audio/15.txt \
    --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
    --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
    --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4