Commit f37ce0b2 in xuwx1/LightX2V

Support nbhd attention (#427)

Authored Nov 01, 2025 by Yang Yong (雍洋); committed via GitHub on Nov 01, 2025
Parent: 6062ef24
Showing 3 changed files with 139 additions and 0 deletions (+139, −0):
configs/attentions/wan_i2v_nbhd.json (+13, −0)
lightx2v/common/ops/attn/__init__.py (+1, −0)
lightx2v/common/ops/attn/nbhd_attn.py (+125, −0)
configs/attentions/wan_i2v_nbhd.json (new file, mode 100755)
```json
{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "nbhd_attn",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "sample_guide_scale": 5,
    "sample_shift": 3,
    "enable_cfg": true,
    "cpu_offload": false
}
```
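The config swaps only the first self-attention type to the new `nbhd_attn` kernel; both cross-attention paths stay on `flash_attn3`. As a minimal sketch of how such a key might be resolved — assuming `ATTN_WEIGHT_REGISTER` supports dict-style lookup, which the actual LightX2V model-loading path may do differently:

```python
# Sketch only: resolving the attention class named in the config.
# Assumes dict-style lookup on ATTN_WEIGHT_REGISTER; the real
# LightX2V loader API may differ.
import json

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

with open("configs/attentions/wan_i2v_nbhd.json") as f:
    cfg = json.load(f)

# "nbhd_attn" is registered by the decorator in nbhd_attn.py below.
self_attn_weight = ATTN_WEIGHT_REGISTER[cfg["self_attn_1_type"]]()
```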
lightx2v/common/ops/attn/__init__.py
```python
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .nbhd_attn import NbhdAttnWeight  # new in this commit
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight, SageAttn3Weight
...
```

Per the diff stats (+1, −0), the `NbhdAttnWeight` import is this file's single added line; the surrounding imports and the elided remainder are unchanged context.
lightx2v/common/ops/attn/nbhd_attn.py (new file, mode 100644)
```python
import torch
from loguru import logger

try:
    from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
    # magi_attention is an optional dependency; without it the kernel is unavailable.
    magi_ffa_func = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
    """
    a         : number of blocks per frame
    block_num : number of blocks per row/column of the mask
    num_frame : total number of frames
    """
    i_indices = torch.arange(block_num, device=device).unsqueeze(1)  # [block_num, 1]
    j_indices = torch.arange(block_num, device=device).unsqueeze(0)  # [1, block_num]

    # 1. attention sink frame: every query attends to the first frame's blocks (j <= a)
    mask_sink = j_indices <= a

    # 2. self-attention within the frame: query block i belongs to frame n = i // a
    n = i_indices // a
    mask_self = (j_indices >= n * a) & (j_indices < (n + 1) * a)

    # 3. cross-frame attention: a band around an offset of n frames, wide (a / 2)
    #    for the adjacent frame, narrow (a / 8) for more distant frames
    mask_cross = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
    for n in range(1, num_frame):
        if n == 1:
            width = 1 / 2 * a
        elif n >= 2:
            width = 1 / 8 * a
        # keys roughly n frames ahead of the query: j - i in (n*a - width, n*a + width]
        mask_1 = (i_indices - j_indices + (n * a + width) >= 0) & (i_indices - j_indices + (n * a - width) < 0)
        # keys roughly n frames behind the query: i - j in (n*a - width, n*a + width]
        mask_2 = (i_indices - j_indices - (n * a - width) > 0) & (i_indices - j_indices - (n * a + width) <= 0)
        mask_cross = mask_cross | mask_1 | mask_2

    # merge all masks
    mask = mask_sink | mask_self | mask_cross
    return mask
```
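To see the pattern, here is a tiny self-contained run of the mask generator with illustrative values (4 frames of 2 blocks each; the numbers are chosen for readability and are not from the commit):

```python
# Illustrative only: 4 frames x 2 blocks per frame -> an 8x8 block mask.
# Row i = query block, column j = key block.
mask = generate_nbhd_mask(a=2, block_num=8, num_frame=4, device="cpu")
print(mask.int())
# Each row keeps the sink columns (j <= a), the blocks of its own frame,
# and narrow bands around offsets of ±n*a for each frame distance n.
```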
```python
def generate_qk_ranges(mask, block_size, seqlen):
    """Convert a block-level boolean mask into per-block-pair token index ranges."""
    indices = torch.nonzero(mask, as_tuple=False)  # [N, 2]
    i_indices = indices[:, 0]  # [N]
    j_indices = indices[:, 1]  # [N]

    q_start = i_indices * block_size  # [N]
    q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen)  # [N]
    k_start = j_indices * block_size  # [N]
    k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen)  # [N]

    q_ranges = torch.stack([q_start, q_end], dim=1)  # [N, 2]
    k_ranges = torch.stack([k_start, k_end], dim=1)  # [N, 2]
    return q_ranges, k_ranges
```
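Continuing the tiny example above, the 8×8 block mask flattens into [start, end) token ranges; with block_size=4 and seqlen=30 (again illustrative values), the last block's end is clamped to the sequence length:

```python
# Illustrative only: flatten the 8x8 block mask into token ranges.
q_ranges, k_ranges = generate_qk_ranges(mask, block_size=4, seqlen=30)
print(q_ranges.shape)              # [N, 2]: one [q_start, q_end) row per kept (i, j) pair
print(q_ranges[-1], k_ranges[-1])  # ends clamped to seqlen=30 rather than 32
```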
```python
@ATTN_WEIGHT_REGISTER("nbhd_attn")
class NbhdAttnWeight(AttnWeightTemplate):
    # Mask state is cached at class level and rebuilt only when
    # (seqlen, num_frame) changes.
    block_size = 128
    seqlen = None
    num_frame = None
    q_ranges = None
    k_ranges = None
    attn_type_map = None

    def __init__(self):
        self.config = {}

    @classmethod
    def prepare_mask(cls, seqlen, num_frame):
        if seqlen == cls.seqlen and num_frame == cls.num_frame:
            return
        block_num = (seqlen + cls.block_size - 1) // cls.block_size
        block_num_per_frame = (seqlen // num_frame + cls.block_size - 1) // cls.block_size
        mask = generate_nbhd_mask(block_num_per_frame, block_num, num_frame, device="cpu")
        q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
        # All-zero attn_type_map: the same attention type for every range pair.
        attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
        q_ranges = q_ranges.to(torch.int32).to("cuda")
        k_ranges = k_ranges.to(torch.int32).to("cuda")
        cls.seqlen = seqlen
        cls.num_frame = num_frame
        cls.q_ranges = q_ranges
        cls.k_ranges = k_ranges
        cls.attn_type_map = attn_type_map
        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}, num_frame={num_frame}")
        sparsity = 1 - mask.sum().item() / mask.numel()
        logger.info(f"Attention sparsity: {sparsity}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """
        q: [seqlen, head_num, head_dim]
        k: [seqlen, head_num, head_dim]
        v: [seqlen, head_num, head_dim]
        """
        # NOTE: the latent frame count is hardcoded, consistent with
        # target_video_length=81 in the config above.
        num_frame = 21
        self.prepare_mask(seqlen=q.shape[0], num_frame=num_frame)
        out = magi_ffa_func(
            q,
            k,
            v,
            q_ranges=self.q_ranges,
            k_ranges=self.k_ranges,
            attn_type_map=self.attn_type_map,
            auto_range_merge=True,
        )[0]
        return out.reshape(out.shape[0], -1)
```
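A minimal smoke test of the new path might look like the following. This is a sketch, not part of the commit: it assumes a CUDA device, an installed magi_attention (otherwise `magi_ffa_func` is `None` and the call fails), bf16 inputs, and 1560 tokens per latent frame with 12 heads of dim 128 (assumed values, not from the source):

```python
# Hypothetical smoke test; not part of the commit.
import torch

from lightx2v.common.ops.attn import NbhdAttnWeight

# 21 latent frames x 1560 tokens per frame (assumed for 480x832)
seqlen, head_num, head_dim = 21 * 1560, 12, 128
q = torch.randn(seqlen, head_num, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = NbhdAttnWeight()
out = attn.apply(q, k, v)  # first call builds and caches the mask/ranges
print(out.shape)           # [seqlen, head_num * head_dim]
```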