xuwx1 / LightX2V: commit e58dd9fe

Authored Jun 25, 2025 by wangshankun
Commit message: Audio-driven Wan video generation (audio驱动wan视频生成)
Parent: 7260cb2e

Showing 11 changed files, with 1280 additions and 13 deletions (+1280, -13).
configs/wan_i2v_audio.json (+12, -0)
lightx2v/infer.py (+5, -2)
lightx2v/models/networks/wan/audio_adapter.py (+414, -0)
lightx2v/models/networks/wan/audio_model.py (+78, -0)
lightx2v/models/networks/wan/infer/post_wan_audio_infer.py (+43, -0)
lightx2v/models/networks/wan/infer/pre_wan_audio_infer.py (+106, -0)
lightx2v/models/networks/wan/infer/transformer_infer.py (+43, -11)
lightx2v/models/networks/wan/infer/utils.py (+40, -0)
lightx2v/models/runners/wan/wan_audio_runner.py (+490, -0)
lightx2v/models/schedulers/wan/scheduler.py (+9, -0)
scripts/run_wan_i2v_audio.sh (+40, -0)
configs/wan_i2v_audio.json (new file, mode 100755)

```json
{
  "infer_steps": 20,
  "target_video_length": 81,
  "target_height": 480,
  "target_width": 832,
  "attention_type": "flash_attn3",
  "seed": 42,
  "sample_guide_scale": 5.0,
  "sample_shift": 5,
  "enable_cfg": false,
  "cpu_offload": false
}
```
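A note on how this file is consumed: the runner receives the parsed CLI arguments plus this JSON. The merge below is a minimal sketch under the assumption that CLI values override JSON keys; the project's real loader may differ.

```python
import json

# Minimal sketch (assumed behaviour): read the JSON config and overlay CLI arguments.
def load_config(config_json_path: str, cli_args: dict) -> dict:
    with open(config_json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    config.update(cli_args)  # assumption: CLI flags take precedence over the JSON file
    return config

config = load_config("configs/wan_i2v_audio.json", {"model_cls": "wan2.1_audio", "task": "i2v"})
print(config["infer_steps"], config["enable_cfg"])  # 20 False
```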
lightx2v/infer.py

```diff
@@ -14,6 +14,7 @@ from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
+from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
@@ -41,14 +42,16 @@ def init_runner(config):
 async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--use_prompt_enhancer", action="store_true")
-    parser.add_argument("--prompt", type=str, required=True)
+    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
     parser.add_argument("--negative_prompt", type=str, default="")
     parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
+    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
     parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
     args = parser.parse_args()
```
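The hunk above only registers the new `wan2.1_audio` choice and the `--audio_path` flag; the mapping from `--model_cls` to `WanAudioRunner` happens inside `init_runner`, which is not shown here. A hypothetical sketch of that dispatch, assuming the runner classes imported at the top of the file:

```python
# Hypothetical sketch of the model_cls -> runner dispatch; the real init_runner() may differ.
RUNNER_MAP = {
    "wan2.1": WanRunner,
    "wan2.1_distill": WanDistillRunner,
    "wan2.1_causvid": WanCausVidRunner,
    "wan2.1_audio": WanAudioRunner,            # class added by this commit
    "wan2.1_skyreels_v2_df": WanSkyreelsV2DFRunner,
    "hunyuan": HunyuanRunner,
    "cogvideox": CogvideoxRunner,
}


def init_runner(config):
    return RUNNER_MAP[config["model_cls"]](config)
```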
lightx2v/models/networks/wan/audio_adapter.py (new file, mode 100644)
```python
import flash_attn
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
from loguru import logger
import pdb
import os
import safetensors
from typing import List, Optional, Tuple, Union


def load_safetensors(in_path: str):
    if os.path.isdir(in_path):
        return load_safetensors_from_dir(in_path)
    elif os.path.isfile(in_path):
        return load_safetensors_from_path(in_path)
    else:
        raise ValueError(f"{in_path} does not exist")


def load_safetensors_from_path(in_path: str):
    tensors = {}
    with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors


def load_safetensors_from_dir(in_dir: str):
    tensors = {}
    safetensors = os.listdir(in_dir)
    safetensors = [f for f in safetensors if f.endswith(".safetensors")]
    for f in safetensors:
        tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
    return tensors


def load_pt_safetensors(in_path: str):
    ext = os.path.splitext(in_path)[-1]
    if ext in (".pt", ".pth", ".tar"):
        state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
    else:
        state_dict = load_safetensors(in_path)
    return state_dict


def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
    import torch.distributed as dist

    if (dist.is_initialized() and dist.get_rank() == 0) or (not dist.is_initialized()):
        state_dict = load_pt_safetensors(in_path)
        model.load_state_dict(state_dict, strict=strict)
    if dist.is_initialized():
        dist.barrier()
    return model.to(dtype=torch.bfloat16, device="cuda")


def linear_interpolation(features, output_len: int):
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=output_len, align_corners=False, mode="linear")
    return output_features.transpose(1, 2)
```
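`linear_interpolation` resamples features along the time axis and is used below to stretch the roughly 50 Hz audio-encoder frames onto the video frame grid. An illustrative shape check using the helper defined above (the audio frame count is made up):

```python
import torch

audio_feats = torch.randn(1, 270, 768)              # (batch, audio_frames, feature_dim), ~5.4 s of WavLM features
video_aligned = linear_interpolation(audio_feats, output_len=81)
print(video_aligned.shape)                          # torch.Size([1, 81, 768]), one feature row per video frame
```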
```python
def get_q_lens_audio_range(
    batchsize,
    n_tokens_per_rank,
    n_query_tokens,
    n_tokens_per_frame,
    sp_rank,
):
    if n_query_tokens == 0:
        q_lens = [1] * batchsize
        return q_lens, 0, 1
    idx0 = n_tokens_per_rank * sp_rank
    first_length = idx0 - idx0 // n_tokens_per_frame * n_tokens_per_frame
    n_frames = (n_query_tokens - first_length) // n_tokens_per_frame
    last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length
    q_lens = []
    if first_length > 0:
        q_lens.append(first_length)
    q_lens += [n_tokens_per_frame] * n_frames
    if last_length > 0:
        q_lens.append(last_length)
    t0 = idx0 // n_tokens_per_frame
    idx1 = idx0 + n_query_tokens
    t1 = math.ceil(idx1 / n_tokens_per_frame)
    return q_lens * batchsize, t0, t1
```
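`get_q_lens_audio_range` splits one sequence-parallel rank's query tokens into per-frame segments for varlen attention and reports which latent frames (`t0:t1`) of audio tokens that rank needs. A worked example with toy numbers (4 tokens per latent frame, 10 query tokens per rank), calling the helper defined above:

```python
# Rank 0 holds tokens 0..9: frames 0 and 1 in full, plus half of frame 2.
print(get_q_lens_audio_range(1, 10, 10, 4, sp_rank=0))  # ([4, 4, 2], 0, 3)

# Rank 1 holds tokens 10..19: the rest of frame 2, then frames 3 and 4.
print(get_q_lens_audio_range(1, 10, 10, 4, sp_rank=1))  # ([2, 4, 4], 2, 5)
```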
```python
class PerceiverAttentionCA(nn.Module):
    def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
        super().__init__()
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads
        kv_dim = inner_dim if kv_dim is None else kv_dim
        self.norm_kv = nn.LayerNorm(kv_dim)
        self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)
        self.to_q = nn.Linear(inner_dim, inner_dim)
        self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
        self.to_out = nn.Linear(inner_dim, inner_dim)
        if adaLN:
            self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
        else:
            shift_scale_gate = torch.zeros((1, 3, inner_dim))
            shift_scale_gate[:, 2] = 1
            self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

    def forward(self, x, latents, t_emb, q_lens, k_lens):
        """x: (batchsize, latent_frame, audio_tokens_per_latent, model_dim); latents: (batchsize, length, model_dim)."""
        batchsize = len(x)
        x = self.norm_kv(x)
        shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
        latents = self.norm_q(latents) * (1 + scale) + shift
        q = self.to_q(latents)
        k, v = self.to_kv(x).chunk(2, dim=-1)
        q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
        k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
        v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads)
        out = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=q_lens.max(),
            max_seqlen_k=k_lens.max(),
            dropout_p=0.0,
            softmax_scale=None,
            causal=False,
            window_size=(-1, -1),
            deterministic=False,
        )
        out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize)
        return self.to_out(out) * gate


class AudioProjection(nn.Module):
    def __init__(
        self,
        audio_feature_dim: int = 768,
        n_neighbors: tuple = (2, 2),
        num_tokens: int = 32,
        mlp_dims: tuple = (1024, 1024, 32 * 768),
        transformer_layers: int = 4,
    ):
        super().__init__()
        mlp = []
        self.left, self.right = n_neighbors
        self.audio_frames = sum(n_neighbors) + 1
        in_dim = audio_feature_dim * self.audio_frames
        for i, out_dim in enumerate(mlp_dims):
            mlp.append(nn.Linear(in_dim, out_dim))
            if i != len(mlp_dims) - 1:
                mlp.append(nn.ReLU())
            in_dim = out_dim
        self.mlp = nn.Sequential(*mlp)
        self.norm = nn.LayerNorm(mlp_dims[-1] // num_tokens)
        self.num_tokens = num_tokens
        if transformer_layers > 0:
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=audio_feature_dim, nhead=audio_feature_dim // 64, dim_feedforward=4 * audio_feature_dim, dropout=0.0, batch_first=True
            )
            self.transformer_decoder = nn.TransformerDecoder(
                decoder_layer,
                num_layers=transformer_layers,
            )
        else:
            self.transformer_decoder = None

    def forward(self, audio_feature, latent_frame):
        video_frame = (latent_frame - 1) * 4 + 1
        audio_feature_ori = audio_feature
        audio_feature = linear_interpolation(audio_feature_ori, video_frame)
        if self.transformer_decoder is not None:
            audio_feature = self.transformer_decoder(audio_feature, audio_feature_ori)
        audio_feature = F.pad(audio_feature, pad=(0, 0, self.left, self.right), mode="replicate")
        audio_feature = audio_feature.unfold(dimension=1, size=self.audio_frames, step=1)
        audio_feature = rearrange(audio_feature, "B T C W -> B T (W C)")
        audio_feature = self.mlp(audio_feature)  # (B, video_frame, C)
        audio_feature = rearrange(audio_feature, "B T (N C) -> B T N C", N=self.num_tokens)  # (B, video_frame, num_tokens, C)
        return self.norm(audio_feature)
```
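`AudioProjection` converts the encoder's per-audio-frame features into `num_tokens` tokens per video frame: interpolate to the video frame rate, optionally refine with a small transformer decoder, window each frame together with its `n_neighbors` context frames, and project through the MLP. A rough shape walkthrough with the class defined above (input sizes are illustrative; `transformer_layers=0` keeps the sketch light):

```python
import torch

proj = AudioProjection(audio_feature_dim=768, n_neighbors=(2, 2), num_tokens=32,
                       mlp_dims=(1024, 1024, 32 * 768), transformer_layers=0)
audio_feature = torch.randn(1, 270, 768)        # (B, audio_frames, C) from the audio encoder
tokens = proj(audio_feature, latent_frame=21)   # 21 latent frames -> (21 - 1) * 4 + 1 = 81 video frames
print(tokens.shape)                             # torch.Size([1, 81, 32, 768]) = (B, video_frame, num_tokens, C)
```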
```python
class TimeEmbedding(nn.Module):
    def __init__(self, dim, time_freq_dim, time_proj_dim):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)

    def forward(self, timestep: torch.Tensor):
        timestep = self.timesteps_proj(timestep)
        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep)
        timestep_proj = self.time_proj(self.act_fn(temb))
        return timestep_proj


class AudioAdapter(nn.Module):
    def __init__(
        self,
        attention_head_dim=64,
        num_attention_heads=40,
        base_num_layers=30,
        interval=1,
        audio_feature_dim: int = 768,
        num_tokens: int = 32,
        mlp_dims: tuple = (1024, 1024, 32 * 768),
        time_freq_dim: int = 256,
        projection_transformer_layers: int = 4,
    ):
        super().__init__()
        self.audio_proj = AudioProjection(
            audio_feature_dim=audio_feature_dim,
            n_neighbors=(2, 2),
            num_tokens=num_tokens,
            mlp_dims=mlp_dims,
            transformer_layers=projection_transformer_layers,
        )
        # self.num_tokens = num_tokens * 4
        self.num_tokens_x4 = num_tokens * 4
        self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
        ca_num = math.ceil(base_num_layers / interval)
        self.base_num_layers = base_num_layers
        self.interval = interval
        self.ca = nn.ModuleList(
            [
                PerceiverAttentionCA(
                    dim_head=attention_head_dim,
                    heads=num_attention_heads,
                    kv_dim=mlp_dims[-1] // num_tokens,
                    adaLN=time_freq_dim > 0,
                )
                for _ in range(ca_num)
            ]
        )
        self.dim = attention_head_dim * num_attention_heads
        if time_freq_dim > 0:
            self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3)
        else:
            self.time_embedding = None

    def rearange_audio_features(self, audio_feature: torch.Tensor):
        # audio_feature (B, video_frame, num_tokens, C)
        audio_feature_0 = audio_feature[:, :1]
        audio_feature_0 = torch.repeat_interleave(audio_feature_0, repeats=4, dim=1)
        audio_feature = torch.cat([audio_feature_0, audio_feature[:, 1:]], dim=1)  # (B, 4 * latent_frame, num_tokens, C)
        audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
        return audio_feature

    def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0):
        def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight):
            """t, h, w specify the latent_frame, latent_height, latent_width after hidden_states is patchified.
            latent_frame does not include the reference images, so the audio tokens and hidden_states stay strictly aligned.
            """
            if len(hidden_states.shape) == 2:  # add a batch dimension (bs = 1)
                hidden_states = hidden_states.unsqueeze(0)
            # print(weight)
            t, h, w = grid_sizes[0].tolist()
            n_tokens = t * h * w
            ori_dtype = hidden_states.dtype
            device = hidden_states.device
            bs, n_tokens_per_rank = hidden_states.shape[:2]
            tail_length = n_tokens_per_rank - n_tokens
            n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
            if n_query_tokens > 0:
                hidden_states_aligned = hidden_states[:, :n_query_tokens]
                hidden_states_tail = hidden_states[:, n_query_tokens:]
            else:
                # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
                hidden_states_aligned = hidden_states[:, :1]
                hidden_states_tail = hidden_states[:, 1:]
            q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0)
            q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
            """
            processing audio features in sp_state can be moved outside.
            """
            x = x[:, t0:t1]
            x = x.to(dtype)
            k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
            assert q_lens.shape == k_lens.shape
            # ca_block is the cross-attention module
            residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
            residual = residual.to(ori_dtype)
            # the audio cross-attention output is injected back as a residual
            if n_query_tokens == 0:
                residual = residual * 0.0
            hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
            if len(hidden_states.shape) == 3:  # squeeze the batch dimension back out (bs = 1)
                hidden_states = hidden_states.squeeze(0)
            return hidden_states

        x = self.audio_proj(audio_feat, latent_frame)
        x = self.rearange_audio_features(x)
        x = x + self.audio_pe
        if self.time_embedding is not None:
            t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
        else:
            t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
        ret_dict = {}
        for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
            block_dict = {
                "kwargs": {
                    "ca_block": self.ca[block_idx],
                    "x": x,
                    "weight": weight,
                    "t_emb": t_emb,
                    "dtype": x.dtype,
                },
                "modify_func": modify_hidden_states,
            }
            ret_dict[base_idx] = block_dict
        return ret_dict

    @classmethod
    def from_transformer(
        cls,
        transformer,
        audio_feature_dim: int = 1024,
        interval: int = 1,
        time_freq_dim: int = 256,
        projection_transformer_layers: int = 4,
    ):
        num_attention_heads = transformer.config["num_heads"]
        base_num_layers = transformer.config["num_layers"]
        attention_head_dim = transformer.config["dim"] // num_attention_heads
        audio_adapter = AudioAdapter(
            attention_head_dim,
            num_attention_heads,
            base_num_layers,
            interval=interval,
            audio_feature_dim=audio_feature_dim,
            time_freq_dim=time_freq_dim,
            projection_transformer_layers=projection_transformer_layers,
            mlp_dims=(1024, 1024, 32 * audio_feature_dim),
        )
        return audio_adapter

    def get_fsdp_wrap_module_list(self):
        ret_list = list(self.ca)
        return ret_list

    def enable_gradient_checkpointing(self):
        pass


class AudioAdapterPipe:
    def __init__(
        self,
        audio_adapter: AudioAdapter,
        audio_encoder_repo: str = "microsoft/wavlm-base-plus",
        dtype=torch.float32,
        device="cuda",
        generator=None,
        tgt_fps: int = 15,
        weight: float = 1.0,
    ) -> None:
        self.audio_adapter = audio_adapter
        self.dtype = dtype
        self.device = device
        self.generator = generator
        self.audio_encoder_dtype = torch.float16
        self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
        self.audio_encoder.eval()
        self.audio_encoder.to(device, self.audio_encoder_dtype)
        self.tgt_fps = tgt_fps
        self.weight = weight
        if "base" in audio_encoder_repo:
            self.audio_feature_dim = 768
        else:
            self.audio_feature_dim = 1024

    def update_model(self, audio_adapter):
        self.audio_adapter = audio_adapter

    def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
        # audio_input_feat is from AudioPreprocessor
        latent_frame = latent_shape[2]
        if len(audio_input_feat.shape) == 1:  # add a batch dimension (batchsize = 1)
            audio_input_feat = audio_input_feat.unsqueeze(0)
            latent_frame = latent_shape[1]
        video_frame = (latent_frame - 1) * 4 + 1
        audio_length = int(50 / self.tgt_fps * video_frame)
        with torch.no_grad():
            audio_input_feat = audio_input_feat.to(self.device, self.audio_encoder_dtype)
            try:
                audio_feat = self.audio_encoder(audio_input_feat, return_dict=True).last_hidden_state
            except Exception as err:
                audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device)
                print(err)
        audio_feat = audio_feat.to(self.dtype)
        if dropout_cond is not None:
            audio_feat = dropout_cond(audio_feat)
        return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)
```
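End-to-end, the adapter is typically built from the transformer's config, loaded from a checkpoint, wrapped in `AudioAdapterPipe`, and queried once per denoising step; the returned dict maps base-layer indices to a `modify_func` that the transformer loop applies after each matching block (see the `transformer_infer.py` hunk below). A hedged wiring sketch: `transformer`, the checkpoint path, and the tensors are placeholders, and the real setup lives in `wan_audio_runner.py` (collapsed below).

```python
# Hypothetical wiring, not the commit's actual runner code.
audio_adapter = AudioAdapter.from_transformer(transformer, audio_feature_dim=768)
audio_adapter = rank0_load_state_dict_from_path(audio_adapter, "/path/to/audio_adapter.safetensors")
pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo="microsoft/wavlm-base-plus",
                        dtype=torch.bfloat16, tgt_fps=15, weight=1.0)

# Once per denoising step: audio_input_feat comes from the audio preprocessor,
# timestep is a 1-element tensor, latent_shape is the latent tensor's shape.
audio_dit_block = pipe(audio_input_feat, timestep, latent_shape=latents.shape)
# audio_dit_block[base_idx]["modify_func"](hidden_states, grid_sizes, **audio_dit_block[base_idx]["kwargs"])
# is what the transformer infer loop calls after DiT block `base_idx`.
```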
lightx2v/models/networks/wan/audio_model.py (new file, mode 100644)
```python
import os
import torch
import time
import glob
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferTeaCaching,
)


class WanAudioModel(WanModel):
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferTeaCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

        self.scheduler.noise_pred = noise_pred_cond

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0

            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()
```
lightx2v/models/networks/wan/infer/post_wan_audio_infer.py (new file, mode 100755)
```python
import math
import torch
import torch.cuda.amp as amp
from loguru import logger
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer


class WanAudioPostInfer(WanPostInfer):
    def __init__(self, config):
        self.out_dim = config["out_dim"]
        self.patch_size = (1, 2, 2)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, x, e, grid_sizes, valid_patch_length):
        if e.dim() == 2:
            modulation = weights.head_modulation.tensor  # 1, 2, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        elif e.dim() == 3:
            # for diffusion forcing
            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
        e = [ei.squeeze(1) for ei in e]

        norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
        out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
        x = weights.head.apply(out)
        x = x[:, :valid_patch_length]
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        x = x.unsqueeze(0)
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[: math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum("fhwpqrc->cfphqwr", u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out
```
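`unpatchify` reverses the (1, 2, 2) patchification: each token of length `prod(patch_size) * out_dim` is folded back into the latent video volume. A self-contained shape check; the grid and channel numbers are assumptions derived from the 480x832, 81-frame config above together with Wan's 4x temporal / 8x spatial VAE compression:

```python
import math
import torch

f, h, w = 21, 30, 52            # assumed patchified grid: (81-1)/4+1 frames, 480/8/2 rows, 832/8/2 cols
c, patch = 16, (1, 2, 2)        # assumed out_dim and the patch_size hard-coded above

tokens = torch.randn(f * h * w, math.prod(patch) * c)
u = tokens.view(f, h, w, *patch, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, f * patch[0], h * patch[1], w * patch[2])
print(u.shape)                  # torch.Size([16, 21, 60, 104]) = (C, T, H, W) latent
```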
lightx2v/models/networks/wan/infer/pre_wan_audio_infer.py (new file, mode 100755)
```python
import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from loguru import logger


class WanAudioPreInfer(WanPreInfer):
    def __init__(self, config):
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        d = config["dim"] // config["num_heads"]
        self.task = config["task"]
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        ).cuda()
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]
        self.text_len = config["text_len"]

    def infer(self, weights, inputs, positive):
        ltnt_channel = self.scheduler.latents.size(0)
        prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
        prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
        hidden_states = self.scheduler.latents.unsqueeze(0)
        hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
        hidden_states = hidden_states.squeeze(0)
        x = [hidden_states]

        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])

        audio_dit_blocks = []
        audio_encoder_output = inputs["audio_encoder_output"]
        audio_model_input = {
            "audio_input_feat": audio_encoder_output.to(hidden_states.device),
            "latent_shape": hidden_states.shape,
            "timestep": t,
        }
        audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))

        if positive:
            context = inputs["text_encoder_output"]["context"]
        else:
            context = inputs["text_encoder_output"]["context_null"]
        seq_len = self.scheduler.seq_len
        clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
        ref_image_encoder = inputs["image_encoder_output"]["vae_encode_out"]

        batch_size = len(x)
        num_channels, num_frames, height, width = x[0].shape
        _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
        if ref_num_channels != num_channels:
            zero_padding = torch.zeros(
                (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
                dtype=self.scheduler.latents.dtype,
                device=self.scheduler.latents.device,
            )
            ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
        y = list(torch.unbind(ref_image_encoder, dim=0))  # unbind the leading batch dimension into a list

        # embeddings
        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
        x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
        assert seq_lens.max() <= seq_len
        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
        valid_patch_length = x[0].size(0)

        y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
        y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
        y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]

        x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
        x = torch.stack(x, dim=0)

        embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
        embed = weights.time_embedding_0.apply(embed)
        embed = torch.nn.functional.silu(embed)
        embed = weights.time_embedding_2.apply(embed)
        embed0 = torch.nn.functional.silu(embed)
        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

        # text embeddings
        stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
        out = weights.text_embedding_0.apply(stacked.squeeze(0))
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = weights.text_embedding_2.apply(out)

        if self.task == "i2v":
            context_clip = weights.proj_0.apply(clip_fea)
            context_clip = weights.proj_1.apply(context_clip)
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
            context_clip = weights.proj_3.apply(context_clip)
            context_clip = weights.proj_4.apply(context_clip)
            context = torch.concat([context_clip, context], dim=0)

        return (
            embed,
            x_grid_sizes,
            (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks),
            valid_patch_length,
        )
```
lightx2v/models/networks/wan/infer/transformer_infer.py
```diff
 import torch
-from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
+from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
 from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
 )
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
 from loguru import logger
 import pdb
 import os
@@ -64,10 +67,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         return cu_seqlens_q, cu_seqlens_k

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
+        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)

-    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
                 self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
@@ -92,7 +95,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(weights.blocks_num):
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
@@ -133,7 +136,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights)
         for block_idx in range(weights.blocks_num):
@@ -194,7 +197,22 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

-    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None):
+        if rotary_emb is None:
+            return None
+        self.use_real = False
+        rope_t_dim = 44
+        if self.use_real:
+            freqs_cos, freqs_sin = rotary_emb
+            freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
+            freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
+            return freqs_cos, freqs_sin
+        else:
+            freqs_cis = rotary_emb
+            freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
+            return freqs_cis
+
+    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
                 weights.blocks[block_idx],
@@ -206,6 +224,12 @@ class WanTransformerInfer(BaseTransformerInfer):
                 freqs,
                 context,
             )
+            if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
+                for ipa_out in audio_dit_blocks:
+                    if block_idx in ipa_out:
+                        cur_modify = ipa_out[block_idx]
+                        x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
@@ -265,14 +289,23 @@ class WanTransformerInfer(BaseTransformerInfer):
         v = weights.self_attn_v.apply(norm1_out).view(s, n, d)

         if not self.parallel_attention:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
+            if self.config.get("audio_sr", False):
+                freqs_i = compute_freqs_audio(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            else:
+                freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         else:
-            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            if self.config.get("audio_sr", False):
+                freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+            else:
+                freqs_i = compute_freqs_dist(q.size(2) // 2, grid_sizes, freqs)
+
+        freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)

         q = self.apply_rotary_emb_func(q, freqs_i)
         k = self.apply_rotary_emb_func(k, freqs_i)

-        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
+        k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
+        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)

         if self.clean_cuda_cache:
             del freqs_i, norm1_out, norm1_weight, norm1_bias
@@ -353,7 +386,6 @@ class WanTransformerInfer(BaseTransformerInfer):
             q,
             k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
         )
         img_attn_out = weights.cross_attn_2.apply(
             q=q,
             k=k_img,
```
lightx2v/models/networks/wan/infer/utils.py
```diff
 import torch
 import torch.distributed as dist
 from loguru import logger
 from lightx2v.utils.envs import *
@@ -19,6 +20,45 @@ def compute_freqs(c, grid_sizes, freqs):
     return freqs_i


+def compute_freqs_audio(c, grid_sizes, freqs):
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+    f, h, w = grid_sizes[0].tolist()
+    f = f + 1
+    seq_len = f * h * w
+    freqs_i = torch.cat(
+        [
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+        ],
+        dim=-1,
+    ).reshape(seq_len, 1, -1)
+    return freqs_i
+
+
+def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
+    world_size = dist.get_world_size()
+    cur_rank = dist.get_rank()
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+    f, h, w = grid_sizes[0].tolist()
+    f = f + 1
+    seq_len = f * h * w
+    freqs_i = torch.cat(
+        [
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+        ],
+        dim=-1,
+    ).reshape(seq_len, 1, -1)
+    freqs_i = pad_freqs(freqs_i, s * world_size)
+    s_per_rank = s
+    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
+    return freqs_i_rank
+
+
 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0].tolist()
```
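The only difference between `compute_freqs_audio` and the existing `compute_freqs` is the `f = f + 1` line: RoPE positions are laid out for one extra latent frame's worth of tokens, which appears to correspond to the reference-image tokens appended after the video tokens in `WanAudioPreInfer` (an inference from this diff, not stated in the commit). A quick tally with the grid used earlier:

```python
f, h, w = 21, 30, 52                 # assumed patchified grid for the 480x832, 81-frame config
video_tokens = f * h * w             # 32760 tokens for the video latents themselves
rope_positions = (f + 1) * h * w     # 34320 positions produced by compute_freqs_audio
print(video_tokens, rope_positions, rope_positions - video_tokens)  # 32760 34320 1560 (= h * w)
```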
lightx2v/models/runners/wan/wan_audio_runner.py (new file, mode 100644; 490 added lines, diff collapsed on the page and not shown)
lightx2v/models/schedulers/wan/scheduler.py

```diff
@@ -115,6 +115,15 @@ class WanScheduler(BaseScheduler):
             x0_pred = sample - sigma_t * model_output
             return x0_pred

+    def reset(self):
+        self.model_outputs = [None] * self.solver_order
+        self.timestep_list = [None] * self.solver_order
+        self.last_sample = None
+        self.noise_pred = None
+        self.this_order = None
+        self.lower_order_nums = 0
+        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+
     def multistep_uni_p_bh_update(
         self,
         model_output: torch.Tensor,
```
scripts/run_wan_i2v_audio.sh (new file, mode 100755)

```bash
#!/bin/bash

# set paths first
lightx2v_path="/mnt/Text2Video/wangshankun/lightx2v"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.1-I2V-Audio-14B-720P/"

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
    --model_cls wan2.1_audio \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan_i2v_audio.json \
    --prompt_path ${lightx2v_path}/assets/inputs/audio/15.txt \
    --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
    --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
    --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
```