Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
837feba7
"tests/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "f5939a51b6389d4b045f3b843481ffb624b54719"
Unverified
Commit
837feba7
authored
Nov 03, 2025
by
Yang Yong (雍洋)
Committed by
GitHub
Nov 03, 2025
Browse files
Update sparse attention (#432)
parent
fc231d3d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
26 deletions
+29
-26
lightx2v/common/ops/attn/nbhd_attn.py
lightx2v/common/ops/attn/nbhd_attn.py
+10
-12
lightx2v/common/ops/attn/svg_attn.py
lightx2v/common/ops/attn/svg_attn.py
+15
-14
lightx2v/models/networks/wan/weights/transformer_weights.py
lightx2v/models/networks/wan/weights/transformer_weights.py
+2
-0
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+2
-0
No files found.
lightx2v/common/ops/attn/nbhd_attn.py
View file @
837feba7
...
...
@@ -11,11 +11,11 @@ from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from
.template
import
AttnWeightTemplate
def
generate_nbhd_mask
(
a
,
block_num
,
num
_frame
,
device
=
"cpu"
):
def
generate_nbhd_mask
(
a
,
block_num
,
attnmap
_frame
_num
,
device
=
"cpu"
):
"""
a : block num per frame
block_num : block num per col/row
num
_frame : total frame num
attnmap
_frame
_num
: total frame num
"""
i_indices
=
torch
.
arange
(
block_num
,
device
=
device
).
unsqueeze
(
1
)
# [block_num, 1]
j_indices
=
torch
.
arange
(
block_num
,
device
=
device
).
unsqueeze
(
0
)
# [1, block_num]
...
...
@@ -29,7 +29,7 @@ def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
# 3. cross-frame attention
mask_cross
=
torch
.
zeros
((
block_num
,
block_num
),
dtype
=
torch
.
bool
,
device
=
device
)
for
n
in
range
(
1
,
num
_frame
):
for
n
in
range
(
1
,
attnmap
_frame
_num
):
if
n
==
1
:
width
=
1
/
2
*
a
elif
n
>=
2
:
...
...
@@ -67,7 +67,7 @@ def generate_qk_ranges(mask, block_size, seqlen):
class
NbhdAttnWeight
(
AttnWeightTemplate
):
block_size
=
128
seqlen
=
None
num
_frame
=
None
attnmap
_frame
_num
=
None
q_ranges
=
None
k_ranges
=
None
attn_type_map
=
None
...
...
@@ -76,22 +76,21 @@ class NbhdAttnWeight(AttnWeightTemplate):
self
.
config
=
{}
@classmethod
def prepare_mask(cls, seqlen):
    """Build and cache neighborhood-attention ranges for a sequence length.

    Recomputes the block-sparse neighborhood mask, the (q, k) block ranges
    and the attention type map only when ``seqlen`` changes. Results are
    stored on the class so all instances share one cached state.

    Args:
        seqlen: total token count of the q/k/v sequences.
    """
    # Cache hit: ranges for this seqlen were already prepared.
    if seqlen == cls.seqlen:
        return
    # Total number of blocks along one side of the attention map (ceil div).
    block_num = (seqlen + cls.block_size - 1) // cls.block_size
    # Blocks per frame, derived from the configured frame count (ceil div).
    block_num_per_frame = (seqlen // cls.attnmap_frame_num + cls.block_size - 1) // cls.block_size
    mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, device="cpu")
    q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
    # One entry per range, all type 0; the kernel expects int32 on the GPU.
    attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
    q_ranges = q_ranges.to(torch.int32).to("cuda")
    k_ranges = k_ranges.to(torch.int32).to("cuda")
    # Publish the cache on the class so every instance picks it up.
    cls.seqlen = seqlen
    cls.q_ranges = q_ranges
    cls.k_ranges = k_ranges
    cls.attn_type_map = attn_type_map
    logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
    # Fraction of attention-map blocks that are masked out.
    sparsity = 1 - mask.sum().item() / mask.numel()
    logger.info(f"Attention sparsity: {sparsity}")
...
...
@@ -111,8 +110,7 @@ class NbhdAttnWeight(AttnWeightTemplate):
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
num_frame
=
21
self
.
prepare_mask
(
seqlen
=
q
.
shape
[
0
],
num_frame
=
num_frame
)
self
.
prepare_mask
(
seqlen
=
q
.
shape
[
0
])
out
=
magi_ffa_func
(
q
,
k
,
...
...
lightx2v/common/ops/attn/svg_attn.py
View file @
837feba7
...
...
@@ -299,8 +299,8 @@ class SvgAttnWeight(AttnWeightTemplate):
sample_mse_max_row
=
None
num_sampled_rows
=
None
context_length
=
None
num
_frame
=
None
frame_size
=
None
attnmap
_frame
_num
=
None
seqlen
=
None
sparsity
=
None
mask_name_list
=
[
"spatial"
,
"temporal"
]
attention_masks
=
None
...
...
@@ -325,18 +325,18 @@ class SvgAttnWeight(AttnWeightTemplate):
self
.
sparse_attention
=
torch
.
compile
(
flex_attention
,
dynamic
=
False
,
mode
=
"max-autotune-no-cudagraphs"
)
@classmethod
def prepare_mask(cls, seqlen):
    """Prepare SVG attention masks and the flex-attention block mask.

    Recomputes only when ``seqlen`` changes; the per-frame size is derived
    from the class-level ``attnmap_frame_num`` configuration.

    Args:
        seqlen: total token count of the q/k/v sequences.
    """
    # Use class attributes so updates affect all instances of this class
    if seqlen == cls.seqlen:
        return
    # Tokens per frame, derived from the configured attention-map frame count.
    frame_size = seqlen // cls.attnmap_frame_num
    cls.attention_masks = [
        get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size)
        for mask_name in cls.mask_name_list
    ]
    # Diagonal width sized so the mask matches the configured sparsity target.
    multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size)
    cls.block_mask = prepare_flexattention(
        1,
        cls.head_num,
        cls.head_dim,
        torch.bfloat16,
        "cuda",
        cls.context_length,
        cls.context_length,
        cls.attnmap_frame_num,
        frame_size,
        diag_width=diag_width,
        multiplier=multiplier,
    )
    cls.seqlen = seqlen
    logger.info(f"SvgAttnWeight Update: seqlen={seqlen}")
def
apply
(
self
,
...
...
@@ -354,18 +354,19 @@ class SvgAttnWeight(AttnWeightTemplate):
v
=
v
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
bs
,
num_heads
,
seq_len
,
dim
=
q
.
size
()
num_frame
=
21
self
.
prepare_mask
(
num_frame
=
num_frame
,
frame_size
=
seq_len
//
num_frame
)
self
.
prepare_mask
(
seq_len
)
sampled_mses
=
self
.
sample_mse
(
q
,
k
,
v
)
best_mask_idx
=
torch
.
argmin
(
sampled_mses
,
dim
=
0
)
output_hidden_states
=
torch
.
zeros_like
(
q
)
query_out
,
key_out
,
value_out
=
torch
.
zeros_like
(
q
),
torch
.
zeros_like
(
k
),
torch
.
zeros_like
(
v
)
query_out
,
key_out
,
value_out
=
self
.
fast_sparse_head_placement
(
q
,
k
,
v
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
self
.
context_length
,
self
.
num_frame
,
self
.
frame_size
)
query_out
,
key_out
,
value_out
=
self
.
fast_sparse_head_placement
(
q
,
k
,
v
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
self
.
context_length
,
self
.
attnmap_frame_num
,
seq_len
//
self
.
attnmap_frame_num
)
hidden_states
=
self
.
sparse_attention
(
query_out
,
key_out
,
value_out
)
wan_hidden_states_placement
(
hidden_states
,
output_hidden_states
,
best_mask_idx
,
self
.
context_length
,
self
.
num_frame
,
self
.
frame_
size
)
wan_hidden_states_placement
(
hidden_states
,
output_hidden_states
,
best_mask_idx
,
self
.
context_length
,
self
.
attnmap_frame_num
,
seq_len
//
self
.
attnmap_
frame_
num
)
return
output_hidden_states
.
reshape
(
bs
,
num_heads
,
seq_len
,
dim
).
transpose
(
1
,
2
).
reshape
(
bs
*
seq_len
,
-
1
)
...
...
lightx2v/models/networks/wan/weights/transformer_weights.py
View file @
837feba7
...
...
@@ -192,6 +192,8 @@ class WanSelfAttention(WeightModule):
context_length
=
self
.
config
.
get
(
"svg_context_length"
,
0
),
sparsity
=
self
.
config
.
get
(
"svg_sparsity"
,
0.25
),
)
if
self
.
config
[
"self_attn_1_type"
]
in
[
"svg_attn"
,
"nbhd_attn"
]:
attention_weights_cls
.
attnmap_frame_num
=
self
.
config
[
"attnmap_frame_num"
]
self
.
add_module
(
"self_attn_1"
,
attention_weights_cls
())
if
self
.
config
[
"seq_parallel"
]:
...
...
lightx2v/utils/set_config.py
View file @
837feba7
...
...
@@ -71,6 +71,8 @@ def set_config(args):
logger
.
warning
(
f
"`num_frames - 1` has to be divisible by
{
config
[
'vae_stride'
][
0
]
}
. Rounding to the nearest number."
)
config
[
"target_video_length"
]
=
config
[
"target_video_length"
]
//
config
[
"vae_stride"
][
0
]
*
config
[
"vae_stride"
][
0
]
+
1
config
[
"attnmap_frame_num"
]
=
((
config
[
"target_video_length"
]
-
1
)
//
config
[
"vae_stride"
][
0
]
+
1
)
//
config
[
"patch_size"
][
0
]
return
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment