xuwx1 / LightX2V · Commits

Commit dd958c79
Authored Aug 06, 2025 by wangshankun
Parent: 820b4450

Support: audio r2v dist infer
Showing 6 changed files with 95 additions and 12 deletions (+95 -12)
lightx2v/models/networks/wan/audio_adapter.py                       +34 -8
lightx2v/models/networks/wan/audio_model.py                         +45 -0
lightx2v/models/networks/wan/infer/dist_infer/transformer_infer.py   +3 -0
lightx2v/models/networks/wan/infer/transformer_infer.py              +1 -3
lightx2v/models/networks/wan/infer/utils.py                          +3 -0
lightx2v/models/runners/wan/wan_audio_runner.py                      +9 -1
lightx2v/models/networks/wan/audio_adapter.py
@@ -12,6 +12,9 @@ import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
+import torch.distributed as dist
+from loguru import logger
 from lightx2v.utils.envs import *
@@ -261,8 +264,8 @@ class AudioAdapter(nn.Module):
         audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
         return audio_feature
 
-    def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0):
+    def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0, seq_p_group=None):
-        def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight):
+        def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight, seq_p_group):
             """thw specify the latent_frame, latent_height, latent_width after
             hidden_states is patchified.
@@ -271,15 +274,27 @@ class AudioAdapter(nn.Module):
             """
             if len(hidden_states.shape) == 2:  # expand the batch-size dim
                 hidden_states = hidden_states.unsqueeze(0)  # bs = 1
             # print(weight)
             t, h, w = grid_sizes[0].tolist()
             n_tokens = t * h * w
             ori_dtype = hidden_states.dtype
             device = hidden_states.device
             bs, n_tokens_per_rank = hidden_states.shape[:2]
-            tail_length = n_tokens_per_rank - n_tokens
-            n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
+            if seq_p_group is not None:
+                sp_size = dist.get_world_size(seq_p_group)
+                sp_rank = dist.get_rank(seq_p_group)
+            else:
+                sp_size = 1
+                sp_rank = 0
+            tail_length = n_tokens_per_rank * sp_size - n_tokens
+            n_unused_ranks = tail_length // n_tokens_per_rank
+            if sp_rank > sp_size - n_unused_ranks - 1:
+                n_query_tokens = 0
+            elif sp_rank == sp_size - n_unused_ranks - 1:
+                n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
+            else:
+                n_query_tokens = n_tokens_per_rank
             if n_query_tokens > 0:
                 hidden_states_aligned = hidden_states[:, :n_query_tokens]
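
Note on the hunk above: the old code assumed a single rank, so the tail was simply n_tokens_per_rank - n_tokens; the new branch generalizes this to sequence parallelism, where the padded sequence spans n_tokens_per_rank * sp_size tokens and trailing ranks may hold partial or pure padding. A standalone sketch of the same arithmetic (the helper name and the toy numbers are illustrative, not from the commit):

    # Toy re-derivation of the per-rank query-token count from the hunk above.
    def n_query_tokens_for_rank(sp_rank, sp_size, n_tokens_per_rank, n_tokens):
        tail_length = n_tokens_per_rank * sp_size - n_tokens  # total padding tokens
        n_unused_ranks = tail_length // n_tokens_per_rank     # ranks holding only padding
        if sp_rank > sp_size - n_unused_ranks - 1:
            return 0                                          # rank sees padding only
        elif sp_rank == sp_size - n_unused_ranks - 1:
            return n_tokens_per_rank - tail_length % n_tokens_per_rank  # boundary rank
        return n_tokens_per_rank                              # fully valid rank

    # 1000 real tokens sharded 256-per-rank across 4 ranks: the last rank keeps 232.
    assert [n_query_tokens_for_rank(r, 4, 256, 1000) for r in range(4)] == [256, 256, 256, 232]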
@@ -289,7 +304,7 @@ class AudioAdapter(nn.Module):
                 hidden_states_aligned = hidden_states[:, :1]
                 hidden_states_tail = hidden_states[:, 1:]
-            q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0)
+            q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
             q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
             """
             processing audio features in sp_state can be moved outside.
@@ -300,6 +315,7 @@ class AudioAdapter(nn.Module):
             assert q_lens.shape == k_lens.shape
             # ca_block: the cross-attention function
             residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
             residual = residual.to(ori_dtype)  # the audio cross-attention output is injected as a residual
             if n_query_tokens == 0:
                 residual = residual * 0.0
@@ -325,6 +341,7 @@ class AudioAdapter(nn.Module):
                 "weight": weight,
                 "t_emb": t_emb,
                 "dtype": x.dtype,
+                "seq_p_group": seq_p_group,
             },
             "modify_func": modify_hidden_states,
         }
@@ -370,8 +387,17 @@ class AudioAdapter(nn.Module):
 class AudioAdapterPipe:
     def __init__(
-        self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", tgt_fps: int = 15, weight: float = 1.0, cpu_offload: bool = False
+        self,
+        audio_adapter: AudioAdapter,
+        audio_encoder_repo: str = "microsoft/wavlm-base-plus",
+        dtype=torch.float32,
+        device="cuda",
+        tgt_fps: int = 15,
+        weight: float = 1.0,
+        cpu_offload: bool = False,
+        seq_p_group=None,
     ) -> None:
+        self.seq_p_group = seq_p_group
         self.audio_adapter = audio_adapter
         self.dtype = dtype
         self.audio_encoder_dtype = torch.float16
@@ -415,4 +441,4 @@ class AudioAdapterPipe:
         if dropout_cond is not None:
             audio_feat = dropout_cond(audio_feat)
-        return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)
+        return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight, seq_p_group=self.seq_p_group)
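
Taken together, the adapter changes thread a single optional seq_p_group handle from AudioAdapterPipe down into modify_hidden_states. A hedged construction sketch (assumes torch.distributed is already initialized, e.g. under torchrun, and that an AudioAdapter instance exists; everything except the seq_p_group wiring mirrors the signature in the hunk above):

    import torch
    import torch.distributed as dist

    # One sequence-parallel group spanning all ranks; passing None keeps the
    # previous single-rank behaviour. This grouping is an assumption, not
    # something the commit prescribes.
    seq_p_group = dist.new_group(ranks=list(range(dist.get_world_size())))
    pipe = AudioAdapterPipe(
        audio_adapter,                  # an AudioAdapter instance, built elsewhere
        audio_encoder_repo="microsoft/wavlm-base-plus",
        dtype=torch.float32,
        device="cuda",
        weight=1.0,
        cpu_offload=False,
        seq_p_group=seq_p_group,
    )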
lightx2v/models/networks/wan/audio_model.py
@@ -13,6 +13,8 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
+from loguru import logger
 
 
 class WanAudioModel(WanModel):
     pre_weight_class = WanPreWeights
@@ -65,6 +67,49 @@ class WanAudioModel(WanModel):
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()
 
+    @torch.no_grad()
+    def infer_wo_cfg_parallel(self, inputs):
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
+                self.to_cuda()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cuda()
+                self.post_weight.to_cuda()
+
+        if self.transformer_infer.mask_map is None:
+            _, c, h, w = self.scheduler.latents.shape
+            num_frame = c + 1  # for r2v
+            video_token_num = num_frame * (h // 2) * (w // 2)
+            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
+
+        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
+        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
+        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
+        self.scheduler.noise_pred = noise_pred_cond
+
+        if self.clean_cuda_cache:
+            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
+            torch.cuda.empty_cache()
+
+        if self.config["enable_cfg"]:
+            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
+            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
+            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
+            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
+
+            if self.clean_cuda_cache:
+                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
+                torch.cuda.empty_cache()
+
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
+                self.to_cpu()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cpu()
+                self.post_weight.to_cpu()
 
 class Wan22MoeAudioModel(WanAudioModel):
     def _load_ckpt(self, unified_dtype, sensitive_layer):
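
The new infer_wo_cfg_parallel runs the conditional pass, then (when enable_cfg is set) an unconditional pass, and combines them with the standard classifier-free guidance formula uncond + scale * (cond - uncond). A self-contained illustration with dummy tensors (shapes and the scale value are made up):

    import torch

    cond = torch.randn(1, 16, 8, 8)    # noise prediction from the conditional pass
    uncond = torch.randn(1, 16, 8, 8)  # noise prediction from the unconditional pass
    scale = 5.0                        # plays the role of sample_guide_scale

    noise_pred = uncond + scale * (cond - uncond)
    # scale = 1.0 recovers the conditional prediction exactly.
    assert torch.allclose(uncond + 1.0 * (cond - uncond), cond)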
lightx2v/models/networks/wan/infer/dist_infer/transformer_infer.py
@@ -76,6 +76,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
         cur_rank = dist.get_rank(self.seq_p_group)
         freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
         f, h, w = grid_sizes[0].tolist()
+        valid_token_length = f * h * w
         f = f + 1
         seq_len = f * h * w
         freqs_i = torch.cat(
@@ -87,6 +88,8 @@ class WanTransformerDistInfer(WanTransformerInfer):
             dim=-1,
         ).reshape(seq_len, 1, -1)
+        freqs_i[valid_token_length:, :, :f] = 0
         freqs_i = pad_freqs(freqs_i, s * world_size)
         s_per_rank = s
         freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
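
The distributed path pads the per-token RoPE table to a multiple of the per-rank length and lets each rank slice out its contiguous shard; the added line zeroes the temporal component of the trailing reference-frame rows before that split. A toy version of the pad-then-slice step (pad_freqs is the repo's helper; the plain zero-pad below is only a stand-in for it, and the shapes are made up):

    import torch

    def pad_to(t, target_len):  # stand-in for pad_freqs
        pad = target_len - t.shape[0]
        return torch.cat([t, t.new_zeros(pad, *t.shape[1:])]) if pad else t

    world_size, s_per_rank = 4, 300
    freqs_i = torch.randn(1170, 1, 64)                   # one row per token
    freqs_i = pad_to(freqs_i, s_per_rank * world_size)   # 1170 -> 1200 rows
    cur_rank = 2
    shard = freqs_i[cur_rank * s_per_rank : (cur_rank + 1) * s_per_rank]
    assert shard.shape[0] == s_per_rank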
lightx2v/models/networks/wan/infer/transformer_infer.py
@@ -11,7 +11,6 @@ from lightx2v.utils.envs import *
 from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
 
 
 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
@@ -33,6 +32,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
+        self.seq_p_group = None
         if self.config.get("cpu_offload", False):
             if torch.cuda.get_device_capability(0) == (9, 0):
                 assert self.config["self_attn_1_type"] != "sage_attn2"
@@ -360,8 +360,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         freqs_i = self.compute_freqs(q, grid_sizes, freqs)
-        freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
         q = self.apply_rotary_emb_func(q, freqs_i)
         k = self.apply_rotary_emb_func(k, freqs_i)
lightx2v/models/networks/wan/infer/utils.py
@@ -22,6 +22,7 @@ def compute_freqs(c, grid_sizes, freqs):
 def compute_freqs_audio(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0]
+    valid_token_length = f * h * w
     f = f + 1  # for r2v: add one frame for the reference embedding
     seq_len = f * h * w
     freqs_i = torch.cat(
@@ -33,6 +34,8 @@ def compute_freqs_audio(c, grid_sizes, freqs):
         dim=-1,
     ).reshape(seq_len, 1, -1)
+    freqs_i[valid_token_length:, :, :f] = 0  # for r2v: zero the temporal component corresponding to the ref embeddings
     return freqs_i
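
This mirrors the deletion of the zero_temporal_component_in_3DRoPE call in transformer_infer.py above: for r2v, one extra latent frame of reference tokens is appended (f = f + 1) and its temporal RoPE entries are zeroed inside compute_freqs_audio itself, so the reference embedding carries no temporal position. A toy illustration of the indexing (shapes are made up; the [..., :f] slice follows the diff):

    import torch

    f, h, w = 4, 2, 2               # latent grid before the reference frame
    valid_token_length = f * h * w  # 16 real video tokens
    f = f + 1                       # append one frame of reference tokens
    seq_len = f * h * w             # 20 tokens in total
    freqs_i = torch.randn(seq_len, 1, 8)
    freqs_i[valid_token_length:, :, :f] = 0  # ref tokens lose their temporal phase
    assert freqs_i[valid_token_length:, :, :f].abs().sum() == 0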
lightx2v/models/runners/wan/wan_audio_runner.py
@@ -426,7 +426,15 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
             device = torch.device("cuda")
         audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload)
+        if self.model.transformer_infer.seq_p_group is not None:
+            seq_p_group = self.model.transformer_infer.seq_p_group
+        else:
+            seq_p_group = None
+        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload, seq_p_group=seq_p_group)
         return self._audio_adapter_pipe
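
Since WanTransformerInfer.__init__ now initializes seq_p_group = None (see transformer_infer.py above), the if/else in this hunk always evaluates to the attribute itself; a more compact equivalent (purely illustrative, not what the commit writes):

    # Equivalent lookup: the attribute is already None when SP is disabled.
    seq_p_group = self.model.transformer_infer.seq_p_group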