Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c47fdfd
Unverified
Commit
8c47fdfd
authored
Mar 24, 2026
by
liangel-02
Committed by
GitHub
Mar 24, 2026
Browse files
[FlexAttention] allow custom mask mod (#37692)
Signed-off-by:
Angel Li
<
liangel@meta.com
>
parent
54b0578a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
16 deletions
+121
-16
tests/kernels/test_flex_attention.py
tests/kernels/test_flex_attention.py
+51
-0
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+70
-16
No files found.
tests/kernels/test_flex_attention.py
View file @
8c47fdfd
...
@@ -14,6 +14,7 @@ from tests.v1.attention.utils import (
...
@@ -14,6 +14,7 @@ from tests.v1.attention.utils import (
create_vllm_config
,
create_vllm_config
,
)
)
from
vllm.v1.attention.backends.flex_attention
import
(
from
vllm.v1.attention.backends.flex_attention
import
(
BlockSparsityHint
,
FlexAttentionMetadataBuilder
,
FlexAttentionMetadataBuilder
,
physical_to_logical_mapping
,
physical_to_logical_mapping
,
)
)
...
@@ -223,5 +224,55 @@ def test_physical_to_logical_mapping_handles_reused_blocks():
...
@@ -223,5 +224,55 @@ def test_physical_to_logical_mapping_handles_reused_blocks():
assert
out2
[
0
,
2
].
item
()
==
1
assert
out2
[
0
,
2
].
item
()
==
1
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
DIRECT_BUILD_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.9"
,
)
def
test_block_sparsity_hint_prunes_blocks
():
"""Test that BlockSparsityHint prunes KV blocks from the direct build path.
Uses a hint that only keeps the diagonal (q_block == kv_block) to verify
that off-diagonal blocks are excluded from the resulting BlockMask.
"""
device
=
torch
.
device
(
"cuda"
)
vllm_config
=
create_vllm_config
(
model_name
=
"facebook/opt-125m"
,
block_size
=
16
,
max_model_len
=
1024
,
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
batch_spec
=
BatchSpec
(
seq_lens
=
[
256
],
query_lens
=
[
256
],
name
=
"test_sparsity_hint"
,
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
vllm_config
.
cache_config
.
block_size
,
device
)
builder
=
FlexAttentionMetadataBuilder
(
kv_cache_spec
,
[],
vllm_config
,
device
)
metadata_no_hint
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
metadata_no_hint
.
block_mask
=
metadata_no_hint
.
_build_block_mask_direct
()
assert
metadata_no_hint
.
block_mask
.
kv_num_blocks
.
max
().
item
()
>
1
def
diagonal_hint
(
q_block_idx
,
kv_block_idx
,
block_size
):
return
q_block_idx
==
kv_block_idx
metadata_with_hint
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
metadata_with_hint
.
block_sparsity_hint
=
BlockSparsityHint
(
hint_fn
=
diagonal_hint
,
)
metadata_with_hint
.
block_mask
=
metadata_with_hint
.
_build_block_mask_direct
()
assert
metadata_with_hint
.
block_mask
.
kv_num_blocks
.
max
().
item
()
<=
1
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
vllm/v1/attention/backends/flex_attention.py
View file @
8c47fdfd
...
@@ -3,9 +3,10 @@
...
@@ -3,9 +3,10 @@
"""Attention layer with FlexAttention."""
"""Attention layer with FlexAttention."""
import
math
import
math
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
ClassVar
from
typing
import
ClassVar
,
NamedTuple
import
torch
import
torch
import
torch._dynamo.decorators
import
torch._dynamo.decorators
...
@@ -294,6 +295,27 @@ def causal_mask_mod(
...
@@ -294,6 +295,27 @@ def causal_mask_mod(
return
q_idx
>=
kv_idx
return
q_idx
>=
kv_idx
# Type alias for the block sparsity hint callable signature.
_block_sparsity_hint_signature
=
Callable
[
[
torch
.
Tensor
,
torch
.
Tensor
,
int
],
torch
.
Tensor
]
class
BlockSparsityHint
(
NamedTuple
):
"""This prunes KV blocks from the BlockMask before the flex_attention kernel
is invoked, so that blocks that are fully masked never get loaded.
Use this with custom mask_mods that are sparse to avoid
the kernel iterating over all KV blocks unnecessarily.
Attributes:
hint_fn: (q_block_idx [num_tokens, 1], kv_block_idx [1, num_kv_blocks],
block_size int) -> bool Tensor [num_tokens, num_kv_blocks].
Returns True for block pairs that may contain non-masked elements.
"""
hint_fn
:
_block_sparsity_hint_signature
@
dataclass
@
dataclass
class
FlexAttentionMetadata
:
class
FlexAttentionMetadata
:
causal
:
bool
causal
:
bool
...
@@ -335,6 +357,7 @@ class FlexAttentionMetadata:
...
@@ -335,6 +357,7 @@ class FlexAttentionMetadata:
transformed_score_mod
:
_score_mod_signature
|
None
=
None
transformed_score_mod
:
_score_mod_signature
|
None
=
None
sliding_window
:
int
|
None
=
None
sliding_window
:
int
|
None
=
None
mm_prefix_range
:
dict
[
int
,
list
[
tuple
[
int
,
int
]]]
|
None
=
None
mm_prefix_range
:
dict
[
int
,
list
[
tuple
[
int
,
int
]]]
|
None
=
None
block_sparsity_hint
:
BlockSparsityHint
|
None
=
None
@
cached_property
@
cached_property
def
logical_block_ids
(
self
):
def
logical_block_ids
(
self
):
...
@@ -378,7 +401,7 @@ class FlexAttentionMetadata:
...
@@ -378,7 +401,7 @@ class FlexAttentionMetadata:
return
is_valid
,
logical_q_idx
,
logical_kv_idx
return
is_valid
,
logical_q_idx
,
logical_kv_idx
def
get_
causal
_mask_mod
(
self
)
->
_mask_mod_signature
:
def
get_
paged
_mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the mask_mod function for FlexAttention.
"""Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles:
This function creates the combined mask mod function that handles:
...
@@ -504,8 +527,9 @@ class FlexAttentionMetadata:
...
@@ -504,8 +527,9 @@ class FlexAttentionMetadata:
def
get_mask_mod
(
self
):
def
get_mask_mod
(
self
):
# Stage-1: initialize the base mask_mod
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
# (causal mask for decoder or bidirectional mask for encoder)
if
self
.
causal
:
has_custom_mask
=
self
.
logical_mask_mod
is
not
causal_mask_mod
mask_mod
=
self
.
get_causal_mask_mod
()
if
self
.
causal
or
has_custom_mask
:
mask_mod
=
self
.
get_paged_mask_mod
()
else
:
else
:
mask_mod
=
self
.
get_bidirectional_mask_mod
()
mask_mod
=
self
.
get_bidirectional_mask_mod
()
# stage-2: add external mask_mod for special attention during
# stage-2: add external mask_mod for special attention during
...
@@ -591,7 +615,9 @@ class FlexAttentionMetadata:
...
@@ -591,7 +615,9 @@ class FlexAttentionMetadata:
self
.
doc_ids
,
:
cdiv
(
self
.
max_seq_len
,
self
.
block_size
)
self
.
doc_ids
,
:
cdiv
(
self
.
max_seq_len
,
self
.
block_size
)
]
]
if
self
.
sliding_window
and
self
.
causal
:
custom_hint
=
self
.
block_sparsity_hint
is
not
None
if
self
.
sliding_window
or
custom_hint
:
device
=
used_pages
.
device
device
=
used_pages
.
device
assert
self
.
doc_ids
is
not
None
assert
self
.
doc_ids
is
not
None
token_indices
=
torch
.
arange
(
token_indices
=
torch
.
arange
(
...
@@ -602,10 +628,24 @@ class FlexAttentionMetadata:
...
@@ -602,10 +628,24 @@ class FlexAttentionMetadata:
-
self
.
query_start_loc
[
self
.
doc_ids
]
-
self
.
query_start_loc
[
self
.
doc_ids
]
+
self
.
decode_offset
[
self
.
doc_ids
]
+
self
.
decode_offset
[
self
.
doc_ids
]
)
)
min_kv_idx
=
torch
.
clamp
(
logical_q_idx
-
(
self
.
sliding_window
-
1
),
min
=
0
)
min_block_idx
=
min_kv_idx
//
self
.
block_size
if
self
.
sliding_window
:
sliding_mask
=
self
.
logical_block_ids
>=
min_block_idx
[:,
None
]
assert
self
.
sliding_window
is
not
None
used_pages
.
masked_fill_
(
~
sliding_mask
,
0
)
min_kv_idx
=
torch
.
clamp
(
logical_q_idx
-
(
self
.
sliding_window
-
1
),
min
=
0
)
min_block_idx
=
min_kv_idx
//
self
.
block_size
sliding_mask
=
self
.
logical_block_ids
>=
min_block_idx
[:,
None
]
used_pages
.
masked_fill_
(
~
sliding_mask
,
0
)
if
custom_hint
:
assert
self
.
block_sparsity_hint
is
not
None
q_block_idx
=
logical_q_idx
//
self
.
block_size
hint_mask
=
self
.
block_sparsity_hint
.
hint_fn
(
q_block_idx
[:,
None
],
self
.
logical_block_ids
[
None
,
:],
self
.
block_size
,
)
used_pages
.
masked_fill_
(
~
hint_mask
,
0
)
used_pages_padded
=
pad_to_multiple
(
used_pages_padded
=
pad_to_multiple
(
used_pages
,
multiple
=
self
.
q_block_size
,
dim
=
0
used_pages
,
multiple
=
self
.
q_block_size
,
dim
=
0
...
@@ -660,11 +700,6 @@ class FlexAttentionMetadata:
...
@@ -660,11 +700,6 @@ class FlexAttentionMetadata:
self
.
mask_mod
=
self
.
get_mask_mod
()
self
.
mask_mod
=
self
.
get_mask_mod
()
self
.
transformed_score_mod
=
self
.
get_transformed_score_mod
()
self
.
transformed_score_mod
=
self
.
get_transformed_score_mod
()
if
self
.
direct_build
and
self
.
causal
:
self
.
block_mask
=
self
.
_build_block_mask_direct
()
else
:
self
.
block_mask
=
self
.
build_block_mask
()
class
FlexAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlexAttentionMetadata
]):
class
FlexAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlexAttentionMetadata
]):
def
__init__
(
def
__init__
(
...
@@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl):
alibi_slopes
:
torch
.
Tensor
|
None
alibi_slopes
:
torch
.
Tensor
|
None
logits_soft_cap
:
float
|
None
logits_soft_cap
:
float
|
None
mm_prefix_range
:
dict
[
int
,
list
[
tuple
[
int
,
int
]]]
|
None
=
None
mm_prefix_range
:
dict
[
int
,
list
[
tuple
[
int
,
int
]]]
|
None
=
None
logical_mask_mod
:
_mask_mod_signature
|
None
=
None
block_sparsity_hint
:
BlockSparsityHint
|
None
=
None
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -907,8 +944,25 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -907,8 +944,25 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata
.
mask_mod
=
attn_metadata
.
get_mask_mod
()
attn_metadata
.
mask_mod
=
attn_metadata
.
get_mask_mod
()
needs_rebuild_block_mask
=
True
needs_rebuild_block_mask
=
True
if
needs_rebuild_block_mask
:
layer_mask_mod
=
getattr
(
layer
,
"logical_mask_mod"
,
None
)
if
attn_metadata
.
direct_build
and
attn_metadata
.
causal
:
if
(
layer_mask_mod
is
not
None
and
attn_metadata
.
logical_mask_mod
is
not
layer_mask_mod
):
attn_metadata
.
logical_mask_mod
=
layer_mask_mod
attn_metadata
.
mask_mod
=
attn_metadata
.
get_mask_mod
()
needs_rebuild_block_mask
=
True
layer_hint
=
getattr
(
layer
,
"block_sparsity_hint"
,
None
)
if
(
layer_hint
is
not
None
and
attn_metadata
.
block_sparsity_hint
is
not
layer_hint
):
attn_metadata
.
block_sparsity_hint
=
layer_hint
needs_rebuild_block_mask
=
True
if
needs_rebuild_block_mask
or
attn_metadata
.
block_mask
is
None
:
if
attn_metadata
.
direct_build
:
attn_metadata
.
block_mask
=
attn_metadata
.
_build_block_mask_direct
()
attn_metadata
.
block_mask
=
attn_metadata
.
_build_block_mask_direct
()
else
:
else
:
attn_metadata
.
block_mask
=
attn_metadata
.
build_block_mask
()
attn_metadata
.
block_mask
=
attn_metadata
.
build_block_mask
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment