Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f825c6bd
Unverified
Commit
f825c6bd
authored
Aug 06, 2025
by
Maximilien de Bayser
Committed by
GitHub
Aug 06, 2025
Browse files
Support encoder_only attention for FlexAttention (#22273)
Signed-off-by:
Max de Bayser
<
mbayser@br.ibm.com
>
parent
41b67f42
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
138 additions
and
47 deletions
+138
-47
tests/kernels/test_flex_attention.py
tests/kernels/test_flex_attention.py
+68
-20
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+70
-27
No files found.
tests/kernels/test_flex_attention.py
View file @
f825c6bd
...
@@ -9,7 +9,9 @@ import pytest
...
@@ -9,7 +9,9 @@ import pytest
import
torch
import
torch
from
packaging
import
version
from
packaging
import
version
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
SamplingParams
from
..models.utils
import
check_embeddings_close
TORCH_VERSION
=
version
.
parse
(
torch
.
__version__
)
TORCH_VERSION
=
version
.
parse
(
torch
.
__version__
)
MINIMUM_TORCH_VERSION
=
version
.
parse
(
"2.7.0"
)
MINIMUM_TORCH_VERSION
=
version
.
parse
(
"2.7.0"
)
...
@@ -28,7 +30,7 @@ def set_seed(seed):
...
@@ -28,7 +30,7 @@ def set_seed(seed):
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
MINIMUM_TORCH_VERSION
,
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
MINIMUM_TORCH_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
)
)
def
test_flex_attention_vs_default_backend
(
monkeypatch
):
def
test_flex_attention_vs_default_backend
(
vllm_runner
,
monkeypatch
):
"""Test that FlexAttention produces the same outputs as the default backend.
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
This test compares the outputs from the FlexAttention backend with
...
@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
...
@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(monkeypatch):
"""
"""
model_name
=
"Qwen/Qwen2.5-1.5B-Instruct"
model_name
=
"Qwen/Qwen2.5-1.5B-Instruct"
seed
=
42
seed
=
42
max_tokens
=
3
2
max_tokens
=
2
4
prompts
=
[
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
...
@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
...
@@ -54,33 +56,30 @@ def test_flex_attention_vs_default_backend(monkeypatch):
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
set_seed
(
seed
)
set_seed
(
seed
)
with
vllm_runner
(
model_name
,
llm_flex
=
LLM
(
runner
=
"generate"
,
model_name
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
)
as
llm_flex
:
enforce_eager
=
True
,
output_flex
=
llm_flex
.
generate
(
prompts
,
sampling_params
)
)
output_flex
=
llm_flex
.
generate
(
prompts
,
sampling_params
)
# Run with default backend
# Run with default backend
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
set_seed
(
seed
)
set_seed
(
seed
)
llm_default
=
LLM
(
with
vllm_runner
(
model_name
,
model_name
,
runner
=
"generate"
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
,
enforce_eager
=
True
)
as
llm_default
:
)
output_default
=
llm_default
.
generate
(
prompts
,
sampling_params
)
output_default
=
llm_default
.
generate
(
prompts
,
sampling_params
)
# Compare outputs from both backends
# Compare outputs from both backends
for
i
,
(
flex_result
,
for
i
,
(
flex_result
,
default_result
)
in
enumerate
(
zip
(
output_flex
,
output_default
)):
default_result
)
in
enumerate
(
zip
(
output_flex
,
output_default
)):
prompt
=
prompts
[
i
]
prompt
=
prompts
[
i
]
flex_text
=
flex_result
.
outputs
[
0
].
text
flex_text
=
flex_result
[
1
][
0
]
default_text
=
default_result
.
outputs
[
0
].
text
default_text
=
default_result
[
1
][
0
]
assert
flex_text
==
default_text
,
(
assert
flex_text
==
default_text
,
(
f
"FlexAttention output doesn't match default for:
{
prompt
!
r
}
\n
"
f
"FlexAttention output doesn't match default for:
{
prompt
!
r
}
\n
"
...
@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
...
@@ -88,5 +87,54 @@ def test_flex_attention_vs_default_backend(monkeypatch):
f
"Default:
{
default_text
!
r
}
"
)
f
"Default:
{
default_text
!
r
}
"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
MINIMUM_TORCH_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
)
def
test_encoder_flex_attention_vs_default_backend
(
vllm_runner
,
monkeypatch
):
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
the default backend for encoder models.
"""
model_name
=
"BAAI/bge-base-en-v1.5"
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
]
# Run with flex attention
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
with
vllm_runner
(
model_name
,
runner
=
"pooling"
,
dtype
=
torch
.
bfloat16
,
tensor_parallel_size
=
1
,
max_model_len
=
100
,
enforce_eager
=
True
)
as
llm_flex
:
flex_outputs
=
llm_flex
.
embed
(
prompts
)
# Run with default backend
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
vllm_runner
(
model_name
,
runner
=
"pooling"
,
dtype
=
torch
.
bfloat16
,
tensor_parallel_size
=
1
,
max_model_len
=
100
,
enforce_eager
=
True
)
as
llm_default
:
default_outputs
=
llm_default
.
embed
(
prompts
)
check_embeddings_close
(
embeddings_0_lst
=
flex_outputs
,
embeddings_1_lst
=
default_outputs
,
name_0
=
"flex"
,
name_1
=
"default"
,
tol
=
1e-2
,
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
vllm/v1/attention/backends/flex_attention.py
View file @
f825c6bd
...
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
...
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
@
dataclass
@
dataclass
class
FlexAttentionMetadata
:
class
FlexAttentionMetadata
:
causal
:
bool
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_actual_tokens
:
int
# Number of tokens excluding padding.
max_query_len
:
int
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
...
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
...
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
num_blocks
=
0
num_blocks
=
0
block_mask
:
Optional
[
BlockMask
]
=
None
block_mask
:
Optional
[
BlockMask
]
=
None
score_mod
:
Optional
[
_score_mod_signature
]
=
None
score_mod
:
Optional
[
_score_mod_signature
]
=
None
mask_mod
:
Optional
[
_mask_mod_signature
]
=
None
logical_mask_mod
:
_mask_mod_signature
=
causal_mask_mod
logical_mask_mod
:
_mask_mod_signature
=
causal_mask_mod
def
get_mask_mod
(
self
)
->
_mask_mod_signature
:
def
get_
causal_
mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the mask_mod function for FlexAttention.
"""Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles:
This function creates the combined mask mod function that handles:
...
@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
...
@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
return
final_mask_mod
return
final_mask_mod
def
get_bidirectional_mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the encoder mask_mod function for FlexAttention.
Since the encoder bidirectional attention doesn't run with
KV cache, this function creates a mask based on the
packed query sequences.
"""
# Create a lookup mapping from query indices -> request number
request_lookup
=
_offsets_to_doc_ids_tensor
(
self
.
query_start_loc
)
def
final_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
kv_idx
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
request_lookup
[
q_idx
]
==
request_lookup
[
kv_idx
]
return
final_mask_mod
def
build_block_mask
(
self
)
->
BlockMask
:
def
build_block_mask
(
self
)
->
BlockMask
:
assert
self
.
mask_mod
is
not
None
if
self
.
causal
:
mask_mod
=
self
.
get_causal_mask_mod
()
kv_len
=
self
.
total_cache_tokens
else
:
mask_mod
=
self
.
get_bidirectional_mask_mod
()
kv_len
=
self
.
num_actual_tokens
return
create_block_mask_compiled
(
return
create_block_mask_compiled
(
self
.
mask_mod
,
mask_mod
,
None
,
None
,
None
,
None
,
self
.
num_actual_tokens
,
self
.
num_actual_tokens
,
self
.
total_cache_tok
en
s
,
kv_l
en
,
device
=
self
.
block_table
.
device
,
device
=
self
.
block_table
.
device
,
)
)
...
@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
...
@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
assert
self
.
prefix_kv_lens
is
None
,
"Not implemented yet."
assert
self
.
prefix_kv_lens
is
None
,
"Not implemented yet."
assert
self
.
suffix_kv_lens
is
None
,
"Not implemented yet."
assert
self
.
suffix_kv_lens
is
None
,
"Not implemented yet."
self
.
num_blocks
=
self
.
total_cache_tokens
//
self
.
block_size
self
.
num_blocks
=
self
.
total_cache_tokens
//
self
.
block_size
self
.
mask_mod
=
self
.
get_mask_mod
()
self
.
block_mask
=
self
.
build_block_mask
()
self
.
block_mask
=
self
.
build_block_mask
()
...
@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
...
@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
out
=
FlexAttentionMetadata
(
out
=
FlexAttentionMetadata
(
causal
=
common_attn_metadata
.
causal
,
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
...
@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
attn_type
=
attn_type
if
attn_type
not
in
(
AttentionType
.
ENCODER_ONLY
,
AttentionType
.
DECODER
):
raise
NotImplementedError
(
f
"FlexAttention does not support
{
attn_type
}
attention"
)
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
not
attn_metadata
.
causal
:
assert
self
.
attn_type
==
AttentionType
.
ENCODER_ONLY
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
query
,
key_tensor
,
value_tensor
=
map
(
value
,
lambda
x
:
self
.
view_as_4d
(
x
).
permute
(
0
,
2
,
1
,
3
),
key_cache
,
(
query
,
key
,
value
),
value_cache
,
)
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
else
:
layer
.
_k_scale
,
assert
self
.
attn_type
==
AttentionType
.
DECODER
layer
.
_v_scale
,
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
)
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# View out the block_size dim
key_cache
=
key_cache
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value_cache
=
value_cache
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
query
,
key_tensor
,
value_tensor
=
map
(
lambda
x
:
self
.
view_as_4d
(
x
).
permute
(
0
,
2
,
1
,
3
),
(
query
,
key_cache
,
value_cache
),
)
# View out the block_size dim
key_cache
=
key_cache
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value_cache
=
value_cache
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
query
,
key_cache
,
value_cache
=
map
(
lambda
x
:
self
.
view_as_4d
(
x
).
permute
(
0
,
2
,
1
,
3
),
(
query
,
key_cache
,
value_cache
),
)
query
=
query
[:,
:,
:
num_actual_tokens
,
:]
query
=
query
[:,
:,
:
num_actual_tokens
,
:]
# Doesn't work for now -> constraint violation
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
# torch._dynamo.try_mark_dynamic(query, 2)
...
@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
out
=
flex_attention_compiled
(
out
=
flex_attention_compiled
(
query
,
query
,
key_
cache
,
key_
tensor
,
value_
cache
,
value_
tensor
,
attn_metadata
.
score_mod
,
attn_metadata
.
score_mod
,
attn_metadata
.
block_mask
,
attn_metadata
.
block_mask
,
self
.
scale
,
self
.
scale
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment