xuwx1 / LightX2V · Commit db6296f2
Authored Nov 10, 2025 by helloyongyang

Support flashinfer for nbhd attn

Parent: adc66e8d
Showing 3 changed files with 73 additions and 3 deletions
lightx2v/common/ops/attn/__init__.py                           +1  -1
lightx2v/common/ops/attn/nbhd_attn.py                          +70 -0
lightx2v/models/networks/wan/weights/transformer_weights.py    +2  -2
lightx2v/common/ops/attn/__init__.py
 from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
-from .nbhd_attn import NbhdAttnWeight
+from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
 from .radial_attn import RadialAttnWeight
 from .ring_attn import RingAttnWeight
 from .sage_attn import SageAttn2Weight, SageAttn3Weight
...
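The only functional change here is re-exporting the new class. A minimal sketch of the updated import surface (names taken directly from the diff above, nothing else assumed):

```python
from lightx2v.common.ops.attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
```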
lightx2v/common/ops/attn/nbhd_attn.py

...
@@ -6,6 +6,11 @@ try:
 except ImportError:
     magi_ffa_func = None
+try:
+    import flashinfer
+except ImportError:
+    flashinfer = None
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 from .template import AttnWeightTemplate
...
@@ -124,3 +129,68 @@ class NbhdAttnWeight(AttnWeightTemplate):
             auto_range_merge=True,
         )[0]
         return out.reshape(out.shape[0], -1)
+
+
+@ATTN_WEIGHT_REGISTER("nbhd_attn_flashinfer")
+class NbhdAttnWeightFlashInfer(AttnWeightTemplate):
+    block_size = 128
+    seqlen = None
+    attnmap_frame_num = None
+    coefficient = [1.0, 0.5, 0.056]
+    min_width = 1.0
+    sparse_wrapper = None
+
+    def __init__(self):
+        self.config = {}
+
+    @classmethod
+    @torch.compiler.disable
+    def prepare_mask(cls, seqlen, head_num, head_dim):
+        if seqlen == cls.seqlen:
+            return
+        block_num = (seqlen + cls.block_size - 1) // cls.block_size
+        block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
+        mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
+        mask = mask.unsqueeze(0).repeat(head_num, 1, 1)
+        block_rowcol_size = torch.ones(block_num, dtype=torch.int32) * cls.block_size
+        block_rowcol_size[-1] = seqlen - cls.block_size * (block_num - 1)
+        block_rowcol_size = block_rowcol_size.unsqueeze(0).repeat(head_num, 1)
+        float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+        cls.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2")
+        cls.sparse_wrapper.plan(
+            block_mask_map=mask,
+            block_row_sz=block_rowcol_size,
+            block_col_sz=block_rowcol_size,
+            num_qo_heads=head_num,
+            num_kv_heads=head_num,
+            head_dim=head_dim,
+            q_data_type=torch.bfloat16,
+        )
+        cls.seqlen = seqlen
+        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
+        sparsity = 1 - mask.sum().item() / mask.numel()
+        logger.info(f"Attention sparsity: {sparsity}")
+
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+    ):
+        """
+        q: [seqlen, head_num, head_dim]
+        k: [seqlen, head_num, head_dim]
+        v: [seqlen, head_num, head_dim]
+        """
+        self.prepare_mask(seqlen=q.shape[0], head_num=q.shape[1], head_dim=q.shape[2])
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        out = self.sparse_wrapper.run(q, k, v)
+        out = out.transpose(0, 1)
+        return out.reshape(out.shape[0], -1)
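For orientation, a minimal sketch of exercising the new weight class directly. This assumes a CUDA device and an installed flashinfer; the sequence length, head count, and attnmap_frame_num below are placeholders for illustration, not values from the commit.

```python
import torch
from lightx2v.common.ops.attn import NbhdAttnWeightFlashInfer

# Hypothetical sizes for illustration only.
seqlen, head_num, head_dim = 75600, 12, 128
NbhdAttnWeightFlashInfer.attnmap_frame_num = 21  # placeholder frame count

q = torch.randn(seqlen, head_num, head_dim, dtype=torch.bfloat16, device="cuda:0")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = NbhdAttnWeightFlashInfer()
# The first call plans the FlashInfer block-sparse wrapper for this seqlen;
# later calls with the same seqlen reuse the cached plan.
out = attn.apply(q, k, v)
print(out.shape)  # (seqlen, head_num * head_dim)
```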
lightx2v/models/networks/wan/weights/transformer_weights.py
...
@@ -192,10 +192,10 @@ class WanSelfAttention(WeightModule):
                 context_length=self.config.get("svg_context_length", 0),
                 sparsity=self.config.get("svg_sparsity", 0.25),
             )
-        if self.config["self_attn_1_type"] in ["svg_attn", "radial_attn", "nbhd_attn"]:
+        if self.config["self_attn_1_type"] in ["svg_attn", "radial_attn", "nbhd_attn", "nbhd_attn_flashinfer"]:
             attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"]
             # nbhd_attn setting
-            if self.config["self_attn_1_type"] == "nbhd_attn":
+            if self.config["self_attn_1_type"] in ["nbhd_attn", "nbhd_attn_flashinfer"]:
                 if "nbhd_attn_setting" in self.config:
                     if "coefficient" in self.config["nbhd_attn_setting"]:
                         attention_weights_cls.coefficient = self.config["nbhd_attn_setting"]["coefficient"]
...
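The new branch is driven entirely by config keys already read in this file. A hypothetical config fragment that would route self-attention through the FlashInfer path might look like the following; only the key names come from the diff above, the values are placeholders:

```python
# Hypothetical config fragment; key names from transformer_weights.py, values are placeholders.
config = {
    "self_attn_1_type": "nbhd_attn_flashinfer",
    "attnmap_frame_num": 21,  # placeholder
    "nbhd_attn_setting": {
        "coefficient": [1.0, 0.5, 0.056],  # matches the class defaults
    },
}
```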