xuwx1 / LightX2V · Commits

Commit 6060ff4f
Authored Jul 03, 2025 by wangshankun
Support: radial attention
Parent: b2147c40
Showing 15 changed files with 297 additions and 48 deletions.
configs/wan_i2v_audio.json                                    +8   -10
lightx2v/attentions/__init__.py                               +3   -0
lightx2v/attentions/common/radial_attn.py                     +199 -0
lightx2v/common/ops/attn/attn_weight.py                       +27  -2
lightx2v/infer.py                                             +3   -1
lightx2v/models/input_encoders/hf/xlm_roberta/model.py        +1   -1
lightx2v/models/networks/wan/audio_adapter.py                 +0   -2
lightx2v/models/networks/wan/audio_model.py                   +11  -0
lightx2v/models/networks/wan/infer/transformer_infer.py       +4   -2
lightx2v/models/networks/wan/infer/utils.py                   +1   -1
lightx2v/models/networks/wan/lora_adapter.py                  +1   -0
lightx2v/models/networks/wan/weights/transformer_weights.py   +3   -3
lightx2v/models/runners/wan/wan_audio_runner.py               +32  -26
lightx2v/models/schedulers/wan/scheduler.py                   +3   -0
scripts/run_wan_i2v_audio.sh                                  +1   -0
configs/wan_i2v_audio.json

 {
-    "infer_steps": 8,
+    "infer_steps": 5,
     "target_fps": 16,
     "video_duration": 12,
     "audio_sr": 16000,
     "target_video_length": 81,
     "target_height": 480,
     "target_width": 832,
     "attention_type": "flash_attn3",
+    "self_attn_1_type": "radial_attn",
+    "cross_attn_1_type": "flash_attn3",
+    "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
     "feature_caching": "Tea",
     "coefficients": [
         [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02],
         [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
     ],
     "use_ret_steps": true,
+    "teacache_thresh": 0.12
-    "cpu_offload": false
 }
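For orientation (not part of the commit), a minimal sketch of reading this config and picking out the per-site attention keys it introduces; the relative path assumes the repository root as the working directory.

import json

with open("configs/wan_i2v_audio.json") as f:
    cfg = json.load(f)

# Radial attention is enabled only for the transformer's self-attention;
# both cross-attention sites stay on FlashAttention-3.
print(cfg["self_attn_1_type"])   # "radial_attn"
print(cfg["cross_attn_1_type"])  # "flash_attn3"
print(cfg["cross_attn_2_type"])  # "flash_attn3"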
lightx2v/attentions/__init__.py

@@ -2,6 +2,7 @@ from lightx2v.attentions.common.torch_sdpa import torch_sdpa
 from lightx2v.attentions.common.flash_attn2 import flash_attn2
 from lightx2v.attentions.common.flash_attn3 import flash_attn3
 from lightx2v.attentions.common.sage_attn2 import sage_attn2
+from lightx2v.attentions.common.radial_attn import radial_attn


 def attention(attention_type="flash_attn2", *args, **kwargs):
@@ -13,5 +14,7 @@ def attention(attention_type="flash_attn2", *args, **kwargs):
         return flash_attn3(*args, **kwargs)
     elif attention_type == "sage_attn2":
         return sage_attn2(*args, **kwargs)
+    elif attention_type == "radial_attn":
+        return radial_attn(*args, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
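A hypothetical call through the dispatcher above (not taken from the diff): it assumes a CUDA device, an installed flashinfer, and the MaskMap helper added in lightx2v/attentions/common/radial_attn.py below; the tensor sizes are illustrative only.

import torch
from lightx2v.attentions import attention
from lightx2v.attentions.common.radial_attn import MaskMap

# Assumed toy layout: 3600 video tokens (4 latent frames x 900 tokens) plus 240 text tokens.
seq_len, heads, head_dim = 3840, 12, 128
q = torch.randn(seq_len, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

mask_map = MaskMap(video_token_num=3600, num_frame=4)
out = attention("radial_attn", q, k, v, mask_map=mask_map, model_cls="wan")
print(out.shape)  # torch.Size([3840, 12, 128]); radial_attn strips any block-size padding before returning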
lightx2v/attentions/common/radial_attn.py (new file)

import torch
import flashinfer

###
### Code from radial-attention
### https://github.com/mit-han-lab/radial-attention/blob/main/radial_attn/attn_mask.py#L150
###
def radial_attn(
    query,
    key,
    value,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    mask_map=None,
    sparsity_type="radial",
    block_size=128,
    decay_factor=1,
    model_cls="wan",
):
    orig_seqlen, num_head, hidden_dim = query.shape

    # Pad q/k/v so the sequence length is a multiple of the block size.
    query = pad_qkv(query, block_size=block_size)
    key = pad_qkv(key, block_size=block_size)
    value = pad_qkv(value, block_size=block_size)

    # Build (or reuse) the block-level log mask for this sequence.
    mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_cls) if mask_map else None

    seqlen = query.shape[0]
    workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
    bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
        workspace_buffer,
        backend="fa2",
    )

    # Convert the boolean block mask into BSR indptr/indices for flashinfer.
    indptr = get_indptr_from_mask(mask, query)
    indices = get_indices_from_mask(mask, query)

    bsr_wrapper.plan(
        indptr=indptr,
        indices=indices,
        M=seqlen,
        N=seqlen,
        R=block_size,
        C=block_size,
        num_qo_heads=num_head,
        num_kv_heads=num_head,
        head_dim=hidden_dim,
        q_data_type=query.dtype,
        kv_data_type=key.dtype,
        o_data_type=query.dtype,
        use_fp16_qk_reduction=True,
    )

    o = bsr_wrapper.run(query, key, value)
    # Drop the padded rows before returning.
    return o[:orig_seqlen, :, :]
def get_indptr_from_mask(mask, query):
    # query shows the device of the indptr
    # indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension,
    # shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension.
    # The first element is always 0, and the last element is the number of blocks in the row dimension.
    # The rest of the elements are the number of blocks in each row.
    # the mask is already a block sparse mask
    indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
    indptr[0] = 0
    row_counts = mask.sum(dim=1).flatten()  # Ensure 1D output [num_blocks_row]
    indptr[1:] = torch.cumsum(row_counts, dim=0)
    return indptr


def get_indices_from_mask(mask, query):
    # indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension,
    # shape `(nnz,)`, where `nnz` is the number of non-zero blocks.
    # The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension.
    nonzero_indices = torch.nonzero(mask)
    indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
    return indices
def shrinkMaskStrict(mask, block_size=128):
    seqlen = mask.shape[0]
    block_num = seqlen // block_size
    mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
    col_densities = mask.sum(dim=1) / block_size
    # we want the minimum non-zero column density in the block
    non_zero_densities = col_densities > 0
    high_density_cols = col_densities > 1 / 3
    frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
    block_mask = frac_high_density_cols > 0.6
    block_mask[0:0] = True
    block_mask[-1:-1] = True
    return block_mask
def pad_qkv(input_tensor, block_size=128):
    """
    Pad the input tensor to be a multiple of the block size.
    input shape: (seqlen, num_heads, hidden_dim)
    """
    seqlen, num_heads, hidden_dim = input_tensor.shape
    # Calculate the necessary padding
    padding_length = (block_size - (seqlen % block_size)) % block_size
    # Create a padded tensor with zeros
    padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
    # Copy the original tensor into the padded tensor
    padded_tensor[:seqlen, :, :] = input_tensor
    return padded_tensor
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    group = dist.bit_length()
    threshold = 128  # hardcoded threshold for now, which is equal to block-size
    decay_length = 2 ** token_per_frame.bit_length() / 2**group
    if decay_length >= threshold:
        return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)

    split_factor = int(threshold / decay_length)
    modular = dist % split_factor
    if modular == 0:
        return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
    else:
        return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    if model_type == "wan":
        if dist < 1:
            return token_per_frame
        if dist == 1:
            return token_per_frame // 2
    elif model_type == "hunyuan":
        if dist <= 1:
            return token_per_frame
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    group = dist.bit_length()
    decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
    threshold = block_size
    if decay_length >= threshold:
        return decay_length
    else:
        return threshold
def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
    """
    A more memory friendly version, we generate the attention mask of each frame pair at a time,
    shrinks it, and stores it into the final result
    """
    final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
    token_per_frame = video_token_num // num_frame
    video_text_border = video_token_num // block_size

    col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
    row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
    final_log_mask[video_text_border:] = True
    final_log_mask[:, video_text_border:] = True
    for i in range(num_frame):
        for j in range(num_frame):
            local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
            if j == 0:  # this is attention sink
                local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
            else:
                window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
                local_mask = torch.abs(col_indices - row_indices) <= window_width
                split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
                local_mask = torch.logical_and(local_mask, split_mask)

            remainder_row = (i * token_per_frame) % block_size
            remainder_col = (j * token_per_frame) % block_size
            # get the padded size
            all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
            all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
            padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
            padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
            # shrink the mask
            block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
            # set the block mask to the final log mask
            block_row_start = (i * token_per_frame) // block_size
            block_col_start = (j * token_per_frame) // block_size
            block_row_end = block_row_start + block_mask.shape[0]
            block_col_end = block_col_start + block_mask.shape[1]
            final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(
                final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask
            )
    print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
    return final_log_mask
class MaskMap:
    def __init__(self, video_token_num=79200, num_frame=22):
        self.video_token_num = video_token_num
        self.num_frame = num_frame
        self.log_mask = None

    def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
        log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool)
        if self.log_mask is None:
            self.log_mask = gen_log_mask_shrinked(query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size)
        block_bound = self.video_token_num // block_size
        log_mask[:block_bound, :block_bound] = self.log_mask[:block_bound, :block_bound]
        return log_mask
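To make the BSR layout concrete, here is a standalone toy (plain torch, CPU) that performs the same conversion as get_indptr_from_mask and get_indices_from_mask on a hand-written 4x4 block mask; the mask values are made up for illustration.

import torch

block_mask = torch.tensor(
    [[1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 0, 1, 1],
     [1, 0, 0, 1]],
    dtype=torch.bool,
)  # each entry stands for a 128x128 tile of the attention map

row_counts = block_mask.sum(dim=1).flatten()
indptr = torch.zeros(block_mask.shape[0] + 1, dtype=torch.int32)
indptr[1:] = torch.cumsum(row_counts, dim=0)
indices = torch.nonzero(block_mask)[:, 1].to(torch.int32)

print(indptr.tolist())   # [0, 2, 5, 8, 10]  -> prefix sums of kept blocks per row
print(indices.tolist())  # [0, 1, 0, 1, 2, 0, 2, 3, 0, 3]  -> kept block columns, row-major

These are the indptr/indices tensors that radial_attn hands to flashinfer's BlockSparseAttentionWrapper.plan.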
lightx2v/common/ops/attn/attn_weight.py

@@ -37,6 +37,9 @@ else:
     sageattn = None

+from lightx2v.attentions.common.radial_attn import radial_attn


 class AttnWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name):
         self.weight_name = weight_name
@@ -70,7 +73,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
         x = flash_attn_varlen_func(
             q,
             k,
@@ -88,7 +91,7 @@ class FlashAttn3Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
         x = flash_attn_varlen_func_v3(
             q,
             k,
@@ -101,6 +104,28 @@ class FlashAttn3Weight(AttnWeightTemplate):
         return x


+@ATTN_WEIGHT_REGISTER("radial_attn")
+class RadialAttnWeight(AttnWeightTemplate):
+    def __init__(self):
+        self.config = {}
+
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
+        assert len(q.shape) == 3
+        x = radial_attn(
+            q,
+            k,
+            v,
+            mask_map=mask_map,
+            sparsity_type=sparsity_type,
+            block_size=block_size,
+            model_cls=model_cls[:3],  # Use first 3 characters to match "wan", "wan2", etc.
+            decay_factor=decay_factor,
+        )
+        x = x.view(max_seqlen_q, -1)
+        return x
+
+
 @ATTN_WEIGHT_REGISTER("sage_attn2")
 class SageAttn2Weight(AttnWeightTemplate):
     def __init__(self):
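A hedged sketch of the call contract for the new weight class (again assuming a CUDA device and flashinfer): unlike the bare radial_attn function, apply() flattens the head dimension via x.view(max_seqlen_q, -1), matching the other *AttnWeight classes in this file. The tensor shapes and MaskMap sizes below are assumptions for illustration.

import torch
from lightx2v.common.ops.attn.attn_weight import ATTN_WEIGHT_REGISTER
from lightx2v.attentions.common.radial_attn import MaskMap

attn = ATTN_WEIGHT_REGISTER["radial_attn"]()          # a RadialAttnWeight instance
q = torch.randn(3840, 12, 128, device="cuda", dtype=torch.bfloat16)  # assumed toy shape
k, v = torch.randn_like(q), torch.randn_like(q)

out = attn.apply(
    q, k, v,
    max_seqlen_q=q.size(0),
    mask_map=MaskMap(video_token_num=3600, num_frame=4),
    model_cls="wan2.1_audio",  # apply() keeps only model_cls[:3] == "wan"
)
print(out.shape)  # torch.Size([3840, 1536]) == (seq_len, heads * head_dim)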
lightx2v/infer.py

@@ -42,7 +42,9 @@ def init_runner(config):
 async def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
lightx2v/models/input_encoders/hf/xlm_roberta/model.py

@@ -449,7 +449,7 @@ class WanVideoIPHandler:
         but Wan2.1 official use no_crop resize by default
         so I don't use CLIPImageProcessor
         """
-        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, subfolder="image_encoder", torch_dtype=dtype)
+        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
         logger.info(f"Using image encoder {model_name} from {repo_or_path}")
         image_encoder.requires_grad_(require_grad)
         if mode == "eval":
lightx2v/models/networks/wan/audio_adapter.py

@@ -7,11 +7,9 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
 from loguru import logger
-import pdb
 import os
 import safetensors
 from typing import List, Optional, Tuple, Union
lightx2v/models/networks/wan/audio_model.py

@@ -19,6 +19,8 @@ from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
+from lightx2v.attentions.common.radial_attn import MaskMap
 from lightx2v.models.networks.wan.infer.transformer_infer import (
     WanTransformerInfer,
 )
@@ -51,6 +53,15 @@ class WanAudioModel(WanModel):
             self.pre_weight.to_cuda()
             self.post_weight.to_cuda()

+        if self.transformer_infer.mask_map is None:
+            _, c, h, w = self.scheduler.latents.shape
+            num_frame = c + 1  # for r2v
+            video_token_num = num_frame * (h // 2) * (w // 2)
+            from loguru import logger
+
+            logger.info(f"video_token_num: {video_token_num}, num_frame: {num_frame}")
+            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
+
         embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
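The token-count arithmetic above, worked through once; the latent shape is an assumption (Wan2.1 at 480x832 with 8x spatial and 4x temporal VAE compression), not a value taken from the diff.

latent_c, latent_h, latent_w = 21, 60, 104      # assumed scheduler.latents shape (without batch dim)

num_frame = latent_c + 1                         # +1 reference frame for r2v
video_token_num = num_frame * (latent_h // 2) * (latent_w // 2)

print(num_frame)                     # 22
print(video_token_num)               # 22 * 30 * 52 = 34320
print(video_token_num // num_frame)  # 1560 tokens per latent frame, the granularity MaskMap works at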
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -7,7 +7,6 @@ from lightx2v.common.offload.manager import (
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
 from loguru import logger
-import pdb
 import os
@@ -24,6 +23,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.parallel_attention = None
         self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
+        self.mask_map = None

         if self.config["cpu_offload"]:
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
@@ -290,7 +291,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if not self.parallel_attention:
             if self.config.get("audio_sr", False):
-                freqs_i = compute_freqs_audio(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
             else:
                 freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         else:
@@ -321,6 +322,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 max_seqlen_q=q.size(0),
                 max_seqlen_kv=k.size(0),
                 model_cls=self.config["model_cls"],
+                mask_map=self.mask_map,
             )
         else:
             attn_out = self.parallel_attention(
lightx2v/models/networks/wan/infer/utils.py

@@ -23,7 +23,7 @@ def compute_freqs(c, grid_sizes, freqs):
 def compute_freqs_audio(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0].tolist()
-    f = f + 1
+    f = f + 1  ## for r2v add 1 channel
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
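For reference, the split sizes used above, evaluated for an assumed head_dim of 128 (so c = 64 rotary channel pairs); the numbers are illustrative, not taken from the diff.

c = 64                                              # assumed: head_dim // 2
split_sizes = [c - 2 * (c // 3), c // 3, c // 3]    # temporal, height, width channels
print(split_sizes)                                  # [22, 21, 21]
assert sum(split_sizes) == c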
lightx2v/models/networks/wan/lora_adapter.py

@@ -50,6 +50,7 @@ class WanLoraWrapper:
         self.model._init_weights(weight_dict)
         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
+        del lora_weights  # delete to free GPU memory
         return True

     @torch.no_grad()
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -184,7 +184,7 @@ class WanSelfAttention(WeightModule):
                 sparge_ckpt = torch.load(self.config["sparge_ckpt"])
                 self.self_attn_1.load(sparge_ckpt)
             else:
-                self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+                self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
         if self.quant_method in ["smoothquant", "awq"]:
             self.add_module(
                 "smooth_norm1_weight",
@@ -275,7 +275,7 @@ class WanCrossAttention(WeightModule):
                 self.lazy_load_file,
             ),
         )
-        self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+        self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
         if self.config.task == "i2v":
             self.add_module(
@@ -304,7 +304,7 @@ class WanCrossAttention(WeightModule):
                 self.lazy_load_file,
             ),
         )
-        self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+        self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]())


 class WanFFN(WeightModule):
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -31,7 +31,6 @@ from torchvision.transforms.functional import resize
 import subprocess
 import warnings
 from typing import Optional, Tuple, Union
-import pdb


 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
@@ -210,7 +209,6 @@ def generate_unique_path(path):
 def save_to_video(gen_lvideo, out_path, target_fps):
     print(gen_lvideo.shape)
     gen_lvideo = rearrange(gen_lvideo, "B C T H W -> B T H W C")
     gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
     gen_lvideo = gen_lvideo[..., ::-1].copy()
@@ -219,21 +217,29 @@ def save_to_video(gen_lvideo, out_path, target_fps):
 def save_audio(
-    audio_array: str,
+    audio_array,
     audio_name: str,
     video_name: str = None,
     sr: int = 16000,
 ):
+    logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
+    if not os.path.exists(audio_name):
-    ta.save(
-        audio_name,
-        torch.tensor(audio_array[None]),
-        sample_rate=sr,
-    )
+        ta.save(
+            audio_name,
+            torch.tensor(audio_array[None]),
+            sample_rate=sr,
+        )
     out_video = f"{video_name[:-4]}_with_audio.mp4"
     # generate_unique_path(out_path)
+    # make sure the parent directory exists
+    parent_dir = os.path.dirname(out_video)
+    if parent_dir and not os.path.exists(parent_dir):
+        os.makedirs(parent_dir, exist_ok=True)
+    # if the output video already exists, delete it first
+    if os.path.exists(out_video):
+        os.remove(out_video)
     cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
     subprocess.call(cmd, shell=True)
@@ -246,6 +252,9 @@ class WanAudioRunner(WanRunner):
     def load_audio_models(self):
        ## audio feature extractor
        self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
        ## adapter for audio-driven video generation
+        audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
        audio_adaper = AudioAdapter.from_transformer(
            self.model,
            audio_feature_dim=1024,
@@ -253,11 +262,11 @@ class WanAudioRunner(WanRunner):
            time_freq_dim=256,
            projection_transformer_layers=4,
        )
-        load_path = "/mnt/aigc/zoemodels/Zoetrained/vigendit/audio_driven/audio_adapter/audio_adapter_V1_0507_bf16.safetensors"
-        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, load_path, strict=False)
+        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, audio_adapter_path, strict=False)
        ## audio feature encoder
        device = self.model.device
-        audio_encoder_repo = "/mnt/aigc/zoemodels/models--TencentGameMate--chinese-hubert-large/snapshots/90cb660492214f687e60f5ca509b20edae6e75bd"
+        audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
        audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
        return audio_adapter_pipe
@@ -275,9 +284,8 @@ class WanAudioRunner(WanRunner):
         return base_model

     def load_image_encoder(self):
-        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path="/mnt/aigc/zoemodels/Wan21/Wan2.1-I2V-14B-720P-Diffusers", require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
+        clip_model_dir = self.config["model_path"] + "/image_encoder"
+        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
         return image_encoder
@@ -325,6 +333,7 @@ class WanAudioRunner(WanRunner):
         self.set_target_shape()
         self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
+        del self.image_encoder  # drop the reference CLIP model; it is only used once
         gc.collect()
         torch.cuda.empty_cache()
@@ -360,15 +369,15 @@ class WanAudioRunner(WanRunner):
         self.inputs["audio_adapter_pipe"] = self.load_audio_models()

         # process audio
-        audio_sr = 16000
-        max_num_frames = 81  # wan2.1: at most 81 frames per segment, 5 s at 16 fps
+        audio_sr = self.config.get("audio_sr", 16000)
+        max_num_frames = self.config.get("target_video_length", 81)  # wan2.1: at most 81 frames per segment, 5 s at 16 fps
         target_fps = self.config.get("target_fps", 16)  # frame rate used to sync audio and video
-        video_duration = self.config.get("video_duration", 8)  # desired output video duration
+        video_duration = self.config.get("video_duration", 5)  # desired output video duration
         audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
         audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
         prev_frame_length = 5
         prev_token_length = (prev_frame_length - 1) // 4 + 1
-        max_num_audio_length = int((max_num_frames + 1) / target_fps * 16000)
+        max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
         interval_num = 1
         # expected_frames
@@ -463,13 +472,10 @@ class WanAudioRunner(WanRunner):
             latents = self.model.scheduler.latents
             generator = self.model.scheduler.generator
             gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
             # gen_img = vae_handler.decode(xt.to(vae_dtype))
             # B, C, T, H, W
             gen_video = torch.clamp(gen_video, -1, 1)
             start_frame = 0 if idx == 0 else prev_frame_length
             start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
             print(f"---- {idx}, {gen_video[:, :, start_frame:].shape}")
             if res_frame_num > 5 and idx == interval_num - 1:
                 gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
@@ -482,7 +488,7 @@ class WanAudioRunner(WanRunner):
         gen_lvideo = torch.cat(gen_video_list, dim=2).float()
         merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
-        out_path = os.path.join("./", "video_merge.mp4")
+        out_path = self.config.save_video_path
         audio_file = os.path.join("./", "audio_merge.wav")
         save_to_video(gen_lvideo, out_path, target_fps)
         save_audio(merge_audio, audio_file, out_path)
@@ -501,5 +507,5 @@ class WanAudioRunner(WanRunner):
         self.run()
         self.end_run()
-        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.empty_cache()
lightx2v/models/schedulers/wan/scheduler.py

 import math
 import numpy as np
 import torch
+import gc
 from typing import List, Optional, Tuple, Union

 from lightx2v.models.schedulers.scheduler import BaseScheduler
@@ -123,6 +124,8 @@ class WanScheduler(BaseScheduler):
         self.this_order = None
         self.lower_order_nums = 0

         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+        gc.collect()
+        torch.cuda.empty_cache()

     def multistep_uni_p_bh_update(
         self,
scripts/run_wan_i2v_audio.sh

@@ -27,6 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
+export DTYPE=BF16

 python -m lightx2v.infer \
     --model_cls wan2.1_audio \