Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f825c6bd
Unverified
Commit
f825c6bd
authored
Aug 06, 2025
by
Maximilien de Bayser
Committed by
GitHub
Aug 06, 2025
Browse files
Support encoder_only attention for FlexAttention (#22273)
Signed-off-by:
Max de Bayser
<
mbayser@br.ibm.com
>
parent
41b67f42
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
138 additions
and
47 deletions
+138
-47
tests/kernels/test_flex_attention.py
tests/kernels/test_flex_attention.py
+68
-20
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+70
-27
No files found.
tests/kernels/test_flex_attention.py
View file @
f825c6bd
...
...
@@ -9,7 +9,9 @@ import pytest
import
torch
from
packaging
import
version
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
SamplingParams
from
..models.utils
import
check_embeddings_close
TORCH_VERSION
=
version
.
parse
(
torch
.
__version__
)
MINIMUM_TORCH_VERSION
=
version
.
parse
(
"2.7.0"
)
...
...
@@ -28,7 +30,7 @@ def set_seed(seed):
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
MINIMUM_TORCH_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
)
def
test_flex_attention_vs_default_backend
(
monkeypatch
):
def
test_flex_attention_vs_default_backend
(
vllm_runner
,
monkeypatch
):
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
...
...
@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
"""
model_name
=
"Qwen/Qwen2.5-1.5B-Instruct"
seed
=
42
max_tokens
=
3
2
max_tokens
=
2
4
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
...
...
@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
set_seed
(
seed
)
llm_flex
=
LLM
(
model_name
,
with
vllm_runner
(
model_name
,
runner
=
"generate"
,
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
,
)
enforce_eager
=
True
)
as
llm_flex
:
output_flex
=
llm_flex
.
generate
(
prompts
,
sampling_params
)
# Run with default backend
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
set_seed
(
seed
)
llm_default
=
LLM
(
model_name
,
with
vllm_runner
(
model_name
,
runner
=
"generate"
,
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
,
)
enforce_eager
=
True
)
as
llm_default
:
output_default
=
llm_default
.
generate
(
prompts
,
sampling_params
)
# Compare outputs from both backends
for
i
,
(
flex_result
,
default_result
)
in
enumerate
(
zip
(
output_flex
,
output_default
)):
prompt
=
prompts
[
i
]
flex_text
=
flex_result
.
outputs
[
0
].
text
default_text
=
default_result
.
outputs
[
0
].
text
flex_text
=
flex_result
[
1
][
0
]
default_text
=
default_result
[
1
][
0
]
assert
flex_text
==
default_text
,
(
f
"FlexAttention output doesn't match default for:
{
prompt
!
r
}
\n
"
...
...
@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
f
"Default:
{
default_text
!
r
}
"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
MINIMUM_TORCH_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
)
def
test_encoder_flex_attention_vs_default_backend
(
vllm_runner
,
monkeypatch
):
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
the default backend for encoder models.
"""
model_name
=
"BAAI/bge-base-en-v1.5"
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
]
# Run with flex attention
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
with
vllm_runner
(
model_name
,
runner
=
"pooling"
,
dtype
=
torch
.
bfloat16
,
tensor_parallel_size
=
1
,
max_model_len
=
100
,
enforce_eager
=
True
)
as
llm_flex
:
flex_outputs
=
llm_flex
.
embed
(
prompts
)
# Run with default backend
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
vllm_runner
(
model_name
,
runner
=
"pooling"
,
dtype
=
torch
.
bfloat16
,
tensor_parallel_size
=
1
,
max_model_len
=
100
,
enforce_eager
=
True
)
as
llm_default
:
default_outputs
=
llm_default
.
embed
(
prompts
)
check_embeddings_close
(
embeddings_0_lst
=
flex_outputs
,
embeddings_1_lst
=
default_outputs
,
name_0
=
"flex"
,
name_1
=
"default"
,
tol
=
1e-2
,
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
vllm/v1/attention/backends/flex_attention.py
View file @
f825c6bd
...
...
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
@
dataclass
class
FlexAttentionMetadata
:
causal
:
bool
num_actual_tokens
:
int
# Number of tokens excluding padding.
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
...
...
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
num_blocks
=
0
block_mask
:
Optional
[
BlockMask
]
=
None
score_mod
:
Optional
[
_score_mod_signature
]
=
None
mask_mod
:
Optional
[
_mask_mod_signature
]
=
None
logical_mask_mod
:
_mask_mod_signature
=
causal_mask_mod
def
get_mask_mod
(
self
)
->
_mask_mod_signature
:
def
get_
causal_
mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles:
...
...
@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
return
final_mask_mod
def
get_bidirectional_mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the encoder mask_mod function for FlexAttention.
Since the encoder bidirectional attention doesn't run with
KV cache, this function creates a mask based on the
packed query sequences.
"""
# Create a lookup mapping from query indices -> request number
request_lookup
=
_offsets_to_doc_ids_tensor
(
self
.
query_start_loc
)
def
final_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
kv_idx
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
request_lookup
[
q_idx
]
==
request_lookup
[
kv_idx
]
return
final_mask_mod
def
build_block_mask
(
self
)
->
BlockMask
:
assert
self
.
mask_mod
is
not
None
if
self
.
causal
:
mask_mod
=
self
.
get_causal_mask_mod
()
kv_len
=
self
.
total_cache_tokens
else
:
mask_mod
=
self
.
get_bidirectional_mask_mod
()
kv_len
=
self
.
num_actual_tokens
return
create_block_mask_compiled
(
self
.
mask_mod
,
mask_mod
,
None
,
None
,
self
.
num_actual_tokens
,
self
.
total_cache_tok
en
s
,
kv_l
en
,
device
=
self
.
block_table
.
device
,
)
...
...
@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
assert
self
.
prefix_kv_lens
is
None
,
"Not implemented yet."
assert
self
.
suffix_kv_lens
is
None
,
"Not implemented yet."
self
.
num_blocks
=
self
.
total_cache_tokens
//
self
.
block_size
self
.
mask_mod
=
self
.
get_mask_mod
()
self
.
block_mask
=
self
.
build_block_mask
()
...
...
@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
self
.
device
,
non_blocking
=
True
)
out
=
FlexAttentionMetadata
(
causal
=
common_attn_metadata
.
causal
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
query_start_loc
=
query_start_loc
,
...
...
@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
attn_type
=
attn_type
if
attn_type
not
in
(
AttentionType
.
ENCODER_ONLY
,
AttentionType
.
DECODER
):
raise
NotImplementedError
(
f
"FlexAttention does not support
{
attn_type
}
attention"
)
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
...
...
@@ -425,6 +456,16 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
if
not
attn_metadata
.
causal
:
assert
self
.
attn_type
==
AttentionType
.
ENCODER_ONLY
query
,
key_tensor
,
value_tensor
=
map
(
lambda
x
:
self
.
view_as_4d
(
x
).
permute
(
0
,
2
,
1
,
3
),
(
query
,
key
,
value
),
)
else
:
assert
self
.
attn_type
==
AttentionType
.
DECODER
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
...
...
@@ -440,11 +481,13 @@ class FlexAttentionImpl(AttentionImpl):
# View out the block_size dim
key_cache
=
key_cache
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value_cache
=
value_cache
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
query
,
key_cache
,
value_cache
=
map
(
value_cache
=
value_cache
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
query
,
key_tensor
,
value_tensor
=
map
(
lambda
x
:
self
.
view_as_4d
(
x
).
permute
(
0
,
2
,
1
,
3
),
(
query
,
key_cache
,
value_cache
),
)
query
=
query
[:,
:,
:
num_actual_tokens
,
:]
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
...
...
@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
out
=
flex_attention_compiled
(
query
,
key_
cache
,
value_
cache
,
key_
tensor
,
value_
tensor
,
attn_metadata
.
score_mod
,
attn_metadata
.
block_mask
,
self
.
scale
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment