xuwx1 / LightX2V

Commit f37ce0b2 (unverified)
Support nbhd attention (#427)
Authored Nov 01, 2025 by Yang Yong (雍洋), committed by GitHub on Nov 01, 2025
Parent: 6062ef24

Showing 3 changed files with 139 additions and 0 deletions:
- configs/attentions/wan_i2v_nbhd.json (+13, -0)
- lightx2v/common/ops/attn/__init__.py (+1, -0)
- lightx2v/common/ops/attn/nbhd_attn.py (+125, -0)
configs/attentions/wan_i2v_nbhd.json (new file, mode 100755)

{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "nbhd_attn",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "sample_guide_scale": 5,
    "sample_shift": 3,
    "enable_cfg": true,
    "cpu_offload": false
}
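The attention type strings appear to be registry names: presumably "self_attn_1_type" switches the first self-attention to the new neighborhood kernel while both cross-attention layers stay on FlashAttention-3. A plain-JSON read of the file for reference (no LightX2V APIs assumed):

import json

with open("configs/attentions/wan_i2v_nbhd.json") as f:
    cfg = json.load(f)

# "nbhd_attn" selects the NbhdAttnWeight class registered in nbhd_attn.py below.
print(cfg["self_attn_1_type"], cfg["cross_attn_1_type"], cfg["cross_attn_2_type"])
# -> nbhd_attn flash_attn3 flash_attn3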
lightx2v/common/ops/attn/__init__.py

from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .nbhd_attn import NbhdAttnWeight
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight, SageAttn3Weight
...
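Importing the package now also runs the @ATTN_WEIGHT_REGISTER("nbhd_attn") decorator in nbhd_attn.py, so the class is both importable directly and registered under its config key. A minimal sketch of the direct import:

from lightx2v.common.ops.attn import NbhdAttnWeight

attn = NbhdAttnWeight()   # exposed by this __init__.py change
# The @ATTN_WEIGHT_REGISTER("nbhd_attn") decorator below also registers it under the
# key "nbhd_attn", which is what the config's "self_attn_1_type" refers to.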
lightx2v/common/ops/attn/nbhd_attn.py (new file, mode 100644)

import torch
from loguru import logger

try:
    from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
    magi_ffa_func = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate
def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
    """
    a : block num per frame
    block_num : block num per col/row
    num_frame : total frame num
    """
    i_indices = torch.arange(block_num, device=device).unsqueeze(1)  # [block_num, 1]
    j_indices = torch.arange(block_num, device=device).unsqueeze(0)  # [1, block_num]

    # 1. attention sink frame: j <= a
    mask_sink = j_indices <= a

    # 2. self-attention within the frame
    n = i_indices // a
    mask_self = (j_indices >= n * a) & (j_indices < (n + 1) * a)

    # 3. cross-frame attention
    mask_cross = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
    for n in range(1, num_frame):
        if n == 1:
            width = 1 / 2 * a
        elif n >= 2:
            width = 1 / 8 * a
        # bands of half-width `width` around the block offsets ±n*a (blocks n frames away)
        mask_1 = (i_indices - j_indices + (n * a + width) >= 0) & (i_indices - j_indices + (n * a - width) < 0)
        mask_2 = (i_indices - j_indices - (n * a - width) > 0) & (i_indices - j_indices - (n * a + width) <= 0)
        mask_cross = mask_cross | mask_1 | mask_2

    # merge all masks
    mask = mask_sink | mask_self | mask_cross
    return mask
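For intuition, the block mask can be inspected directly. The sizes below are illustrative only (not from the diff) and mirror the arithmetic that prepare_mask performs further down:

# Toy sanity check of the neighborhood block mask (sizes are illustrative only).
from lightx2v.common.ops.attn.nbhd_attn import generate_nbhd_mask

block_size = 128
seqlen, num_frame = 128 * 60, 20                       # hypothetical sequence: 60 blocks, 20 frames
block_num = (seqlen + block_size - 1) // block_size    # 60
block_num_per_frame = (seqlen // num_frame + block_size - 1) // block_size  # 3

mask = generate_nbhd_mask(block_num_per_frame, block_num, num_frame)
sparsity = 1 - mask.sum().item() / mask.numel()
print(mask.shape, sparsity)  # [60, 60] block mask and the fraction of skipped blocks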
def generate_qk_ranges(mask, block_size, seqlen):
    indices = torch.nonzero(mask, as_tuple=False)  # shape: [N, 2]
    i_indices = indices[:, 0]  # [N]
    j_indices = indices[:, 1]  # [N]
    q_start = i_indices * block_size  # [N]
    q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen)  # [N]
    k_start = j_indices * block_size  # [N]
    k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen)  # [N]
    q_ranges = torch.stack([q_start, q_end], dim=1)  # [N, 2]
    k_ranges = torch.stack([k_start, k_end], dim=1)  # [N, 2]
    return q_ranges, k_ranges
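Each True block in the mask then becomes one half-open token range pair, clamped to the sequence length. A small continuation of the sketch above, with the same illustrative sizes:

# Convert the block mask into token-level (start, end) ranges (illustrative sizes).
from lightx2v.common.ops.attn.nbhd_attn import generate_nbhd_mask, generate_qk_ranges

block_size = 128
seqlen, num_frame = 128 * 60, 20
block_num = (seqlen + block_size - 1) // block_size
block_num_per_frame = (seqlen // num_frame + block_size - 1) // block_size

mask = generate_nbhd_mask(block_num_per_frame, block_num, num_frame)
q_ranges, k_ranges = generate_qk_ranges(mask, block_size, seqlen)
print(q_ranges.shape, k_ranges.shape)  # [N, 2] each: one pair per active block
print(q_ranges[0], k_ranges[0])        # the first active block's query/key token span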
@ATTN_WEIGHT_REGISTER("nbhd_attn")
class NbhdAttnWeight(AttnWeightTemplate):
    block_size = 128
    seqlen = None
    num_frame = None
    q_ranges = None
    k_ranges = None
    attn_type_map = None

    def __init__(self):
        self.config = {}

    @classmethod
    def prepare_mask(cls, seqlen, num_frame):
        # Recompute the block mask and ranges only when the sequence layout changes.
        if seqlen == cls.seqlen and num_frame == cls.num_frame:
            return
        block_num = (seqlen + cls.block_size - 1) // cls.block_size
        block_num_per_frame = (seqlen // num_frame + cls.block_size - 1) // cls.block_size
        mask = generate_nbhd_mask(block_num_per_frame, block_num, num_frame, device="cpu")
        q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
        attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
        q_ranges = q_ranges.to(torch.int32).to("cuda")
        k_ranges = k_ranges.to(torch.int32).to("cuda")
        cls.seqlen = seqlen
        cls.num_frame = num_frame
        cls.q_ranges = q_ranges
        cls.k_ranges = k_ranges
        cls.attn_type_map = attn_type_map
        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}, num_frame={num_frame}")
        sparsity = 1 - mask.sum().item() / mask.numel()
        logger.info(f"Attention sparsity: {sparsity}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """
        q: [seqlen, head_num, head_dim]
        k: [seqlen, head_num, head_dim]
        v: [seqlen, head_num, head_dim]
        """
        num_frame = 21  # NOTE: hardcoded latent frame count
        self.prepare_mask(seqlen=q.shape[0], num_frame=num_frame)
        out = magi_ffa_func(
            q,
            k,
            v,
            q_ranges=self.q_ranges,
            k_ranges=self.k_ranges,
            attn_type_map=self.attn_type_map,
            auto_range_merge=True,
        )[0]
        return out.reshape(out.shape[0], -1)
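A minimal usage sketch, assuming the optional magi_attention package and a CUDA device are available (magi_ffa_func is None otherwise and apply would fail). Shapes, head counts, and dtype below are illustrative, not taken from the diff; apply() itself hardcodes 21 latent frames:

import torch

from lightx2v.common.ops.attn import NbhdAttnWeight

seqlen, head_num, head_dim = 21 * 1536, 12, 128  # hypothetical sizes
q = torch.randn(seqlen, head_num, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = NbhdAttnWeight()
out = attn.apply(q, k, v)  # heads are flattened on return
print(out.shape)           # [seqlen, head_num * head_dim]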