Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
837feba7
Unverified
Commit
837feba7
authored
Nov 03, 2025
by
Yang Yong (雍洋)
Committed by
GitHub
Nov 03, 2025
Browse files
Update sparse attention (#432)
parent
fc231d3d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
26 deletions
+29
-26
lightx2v/common/ops/attn/nbhd_attn.py
lightx2v/common/ops/attn/nbhd_attn.py
+10
-12
lightx2v/common/ops/attn/svg_attn.py
lightx2v/common/ops/attn/svg_attn.py
+15
-14
lightx2v/models/networks/wan/weights/transformer_weights.py
lightx2v/models/networks/wan/weights/transformer_weights.py
+2
-0
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+2
-0
No files found.
lightx2v/common/ops/attn/nbhd_attn.py
View file @
837feba7
...
...
@@ -11,11 +11,11 @@ from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from
.template
import
AttnWeightTemplate
def
generate_nbhd_mask
(
a
,
block_num
,
num
_frame
,
device
=
"cpu"
):
def
generate_nbhd_mask
(
a
,
block_num
,
attnmap
_frame
_num
,
device
=
"cpu"
):
"""
a : block num per frame
block_num : block num per col/row
num
_frame : total frame num
attnmap
_frame
_num
: total frame num
"""
i_indices
=
torch
.
arange
(
block_num
,
device
=
device
).
unsqueeze
(
1
)
# [block_num, 1]
j_indices
=
torch
.
arange
(
block_num
,
device
=
device
).
unsqueeze
(
0
)
# [1, block_num]
...
...
@@ -29,7 +29,7 @@ def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
# 3. cross-frame attention
mask_cross
=
torch
.
zeros
((
block_num
,
block_num
),
dtype
=
torch
.
bool
,
device
=
device
)
for
n
in
range
(
1
,
num
_frame
):
for
n
in
range
(
1
,
attnmap
_frame
_num
):
if
n
==
1
:
width
=
1
/
2
*
a
elif
n
>=
2
:
...
...
@@ -67,7 +67,7 @@ def generate_qk_ranges(mask, block_size, seqlen):
class
NbhdAttnWeight
(
AttnWeightTemplate
):
block_size
=
128
seqlen
=
None
num
_frame
=
None
attnmap
_frame
_num
=
None
q_ranges
=
None
k_ranges
=
None
attn_type_map
=
None
...
...
@@ -76,22 +76,21 @@ class NbhdAttnWeight(AttnWeightTemplate):
self
.
config
=
{}
@
classmethod
def
prepare_mask
(
cls
,
seqlen
,
num_frame
):
if
seqlen
==
cls
.
seqlen
and
num_frame
==
cls
.
num_frame
:
def
prepare_mask
(
cls
,
seqlen
):
if
seqlen
==
cls
.
seqlen
:
return
block_num
=
(
seqlen
+
cls
.
block_size
-
1
)
//
cls
.
block_size
block_num_per_frame
=
(
seqlen
//
num
_frame
+
cls
.
block_size
-
1
)
//
cls
.
block_size
mask
=
generate_nbhd_mask
(
block_num_per_frame
,
block_num
,
num
_frame
,
device
=
"cpu"
)
block_num_per_frame
=
(
seqlen
//
cls
.
attnmap
_frame
_num
+
cls
.
block_size
-
1
)
//
cls
.
block_size
mask
=
generate_nbhd_mask
(
block_num_per_frame
,
block_num
,
cls
.
attnmap
_frame
_num
,
device
=
"cpu"
)
q_ranges
,
k_ranges
=
generate_qk_ranges
(
mask
,
cls
.
block_size
,
seqlen
)
attn_type_map
=
torch
.
zeros
(
len
(
q_ranges
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
q_ranges
=
q_ranges
.
to
(
torch
.
int32
).
to
(
"cuda"
)
k_ranges
=
k_ranges
.
to
(
torch
.
int32
).
to
(
"cuda"
)
cls
.
seqlen
=
seqlen
cls
.
num_frame
=
num_frame
cls
.
q_ranges
=
q_ranges
cls
.
k_ranges
=
k_ranges
cls
.
attn_type_map
=
attn_type_map
logger
.
info
(
f
"NbhdAttnWeight Update: seqlen=
{
seqlen
}
, num_frame=
{
num_frame
}
"
)
logger
.
info
(
f
"NbhdAttnWeight Update: seqlen=
{
seqlen
}
"
)
sparsity
=
1
-
mask
.
sum
().
item
()
/
mask
.
numel
()
logger
.
info
(
f
"Attention sparsity:
{
sparsity
}
"
)
...
...
@@ -111,8 +110,7 @@ class NbhdAttnWeight(AttnWeightTemplate):
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
num_frame
=
21
self
.
prepare_mask
(
seqlen
=
q
.
shape
[
0
],
num_frame
=
num_frame
)
self
.
prepare_mask
(
seqlen
=
q
.
shape
[
0
])
out
=
magi_ffa_func
(
q
,
k
,
...
...
lightx2v/common/ops/attn/svg_attn.py
View file @
837feba7
...
...
@@ -299,8 +299,8 @@ class SvgAttnWeight(AttnWeightTemplate):
sample_mse_max_row
=
None
num_sampled_rows
=
None
context_length
=
None
num
_frame
=
None
frame_size
=
None
attnmap
_frame
_num
=
None
seqlen
=
None
sparsity
=
None
mask_name_list
=
[
"spatial"
,
"temporal"
]
attention_masks
=
None
...
...
@@ -325,18 +325,18 @@ class SvgAttnWeight(AttnWeightTemplate):
self
.
sparse_attention
=
torch
.
compile
(
flex_attention
,
dynamic
=
False
,
mode
=
"max-autotune-no-cudagraphs"
)
@classmethod
def prepare_mask(cls, seqlen):
    """Build and cache the SVG flex-attention masks for `seqlen`.

    Derives the per-frame size from cls.attnmap_frame_num, rebuilds the
    candidate attention masks and the compiled flex-attention block mask,
    and caches them on the class. No-op when seqlen is unchanged.

    Args:
        seqlen: total token count of the attention map
            (assumed divisible by cls.attnmap_frame_num — TODO confirm).
    """
    # Use class attributes so updates affect all instances of this class
    if seqlen == cls.seqlen:
        return
    # Tokens per frame, derived from the configured frame count.
    frame_size = seqlen // cls.attnmap_frame_num
    # One candidate mask per name in mask_name_list (e.g. spatial / temporal).
    cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size) for mask_name in cls.mask_name_list]
    # Diagonal width sized so the resulting mask matches the configured sparsity.
    multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size)
    # Precompile the flex-attention block mask on GPU (batch size 1, bfloat16).
    cls.block_mask = prepare_flexattention(
        1,
        cls.head_num,
        cls.head_dim,
        torch.bfloat16,
        "cuda",
        cls.context_length,
        cls.context_length,
        cls.attnmap_frame_num,
        frame_size,
        diag_width=diag_width,
        multiplier=multiplier
    )
    cls.seqlen = seqlen
    logger.info(f"SvgAttnWeight Update: seqlen={seqlen}")
def
apply
(
self
,
...
...
@@ -354,18 +354,19 @@ class SvgAttnWeight(AttnWeightTemplate):
v
=
v
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
bs
,
num_heads
,
seq_len
,
dim
=
q
.
size
()
num_frame
=
21
self
.
prepare_mask
(
num_frame
=
num_frame
,
frame_size
=
seq_len
//
num_frame
)
self
.
prepare_mask
(
seq_len
)
sampled_mses
=
self
.
sample_mse
(
q
,
k
,
v
)
best_mask_idx
=
torch
.
argmin
(
sampled_mses
,
dim
=
0
)
output_hidden_states
=
torch
.
zeros_like
(
q
)
query_out
,
key_out
,
value_out
=
torch
.
zeros_like
(
q
),
torch
.
zeros_like
(
k
),
torch
.
zeros_like
(
v
)
query_out
,
key_out
,
value_out
=
self
.
fast_sparse_head_placement
(
q
,
k
,
v
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
self
.
context_length
,
self
.
num_frame
,
self
.
frame_size
)
query_out
,
key_out
,
value_out
=
self
.
fast_sparse_head_placement
(
q
,
k
,
v
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
self
.
context_length
,
self
.
attnmap_frame_num
,
seq_len
//
self
.
attnmap_frame_num
)
hidden_states
=
self
.
sparse_attention
(
query_out
,
key_out
,
value_out
)
wan_hidden_states_placement
(
hidden_states
,
output_hidden_states
,
best_mask_idx
,
self
.
context_length
,
self
.
num_frame
,
self
.
frame_
size
)
wan_hidden_states_placement
(
hidden_states
,
output_hidden_states
,
best_mask_idx
,
self
.
context_length
,
self
.
attnmap_frame_num
,
seq_len
//
self
.
attnmap_
frame_
num
)
return
output_hidden_states
.
reshape
(
bs
,
num_heads
,
seq_len
,
dim
).
transpose
(
1
,
2
).
reshape
(
bs
*
seq_len
,
-
1
)
...
...
lightx2v/models/networks/wan/weights/transformer_weights.py
View file @
837feba7
...
...
@@ -192,6 +192,8 @@ class WanSelfAttention(WeightModule):
context_length
=
self
.
config
.
get
(
"svg_context_length"
,
0
),
sparsity
=
self
.
config
.
get
(
"svg_sparsity"
,
0.25
),
)
if
self
.
config
[
"self_attn_1_type"
]
in
[
"svg_attn"
,
"nbhd_attn"
]:
attention_weights_cls
.
attnmap_frame_num
=
self
.
config
[
"attnmap_frame_num"
]
self
.
add_module
(
"self_attn_1"
,
attention_weights_cls
())
if
self
.
config
[
"seq_parallel"
]:
...
...
lightx2v/utils/set_config.py
View file @
837feba7
...
...
@@ -71,6 +71,8 @@ def set_config(args):
logger
.
warning
(
f
"`num_frames - 1` has to be divisible by
{
config
[
'vae_stride'
][
0
]
}
. Rounding to the nearest number."
)
config
[
"target_video_length"
]
=
config
[
"target_video_length"
]
//
config
[
"vae_stride"
][
0
]
*
config
[
"vae_stride"
][
0
]
+
1
config
[
"attnmap_frame_num"
]
=
((
config
[
"target_video_length"
]
-
1
)
//
config
[
"vae_stride"
][
0
]
+
1
)
//
config
[
"patch_size"
][
0
]
return
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment