Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e0329ed4
Unverified
Commit
e0329ed4
authored
Aug 25, 2025
by
Driss Guessous
Committed by
GitHub
Aug 25, 2025
Browse files
Updates to Flex + VLLm integration (#21416)
Signed-off-by:
drisspg
<
drisspguessous@gmail.com
>
parent
6879cd80
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
439 additions
and
103 deletions
+439
-103
tests/kernels/test_flex_attention.py
tests/kernels/test_flex_attention.py
+87
-23
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+18
-12
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+334
-68
No files found.
tests/kernels/test_flex_attention.py
View file @
e0329ed4
...
...
@@ -9,12 +9,17 @@ import pytest
import
torch
from
packaging
import
version
from
vllm
import
SamplingParams
from
tests.v1.attention.utils
import
(
BatchSpec
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
create_vllm_config
)
from
vllm.v1.attention.backends.flex_attention
import
(
FlexAttentionMetadataBuilder
)
from
..models.utils
import
check_embeddings_close
from
..models.utils
import
check_embeddings_close
,
check_logprobs_close
TORCH_VERSION
=
version
.
parse
(
torch
.
__version__
)
MINIMUM_TORCH_VERSION
=
version
.
parse
(
"2.7.0"
)
DIRECT_BUILD_VERSION
=
version
.
parse
(
"2.9.dev0"
)
def
set_seed
(
seed
):
...
...
@@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are
identical
when using the same seed.
the default backend, ensuring they are
similar
when using the same seed.
"""
model_name
=
"Qwen/Qwen2.5-1.5B-Instruct"
seed
=
42
max_tokens
=
24
num_logprobs
=
5
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
,
seed
=
seed
,
max_tokens
=
max_tokens
)
# Run with flex attention
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
...
...
@@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
)
as
llm_flex
:
output_flex
=
llm_flex
.
generate
(
prompts
,
sampling_params
)
output_flex
=
llm_flex
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
)
# Run with default backend
with
monkeypatch
.
context
()
as
m
:
...
...
@@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
runner
=
"generate"
,
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
)
as
llm_default
:
output_default
=
llm_default
.
generate
(
prompts
,
sampling_params
)
# Compare outputs from both backends
for
i
,
(
flex_result
,
default_result
)
in
enumerate
(
zip
(
output_flex
,
output_default
)):
prompt
=
prompts
[
i
]
flex_text
=
flex_result
[
1
][
0
]
default_text
=
default_result
[
1
][
0
]
assert
flex_text
==
default_text
,
(
f
"FlexAttention output doesn't match default for:
{
prompt
!
r
}
\n
"
f
"FlexAttention:
{
flex_text
!
r
}
\n
"
f
"Default:
{
default_text
!
r
}
"
)
enforce_eager
=
True
,
gpu_memory_utilization
=
0.85
)
as
llm_default
:
output_default
=
llm_default
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
output_flex
,
outputs_1_lst
=
output_default
,
name_0
=
"flex"
,
name_1
=
"default"
,
)
@
pytest
.
mark
.
skipif
(
...
...
@@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
DIRECT_BUILD_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
)
def
test_block_mask_direct_vs_slow_path
():
"""Test that direct path block mask is a superset of slow path.
The direct path may include extra blocks for performance (over-estimation),
but must include all blocks that the slow path determines are necessary.
"""
device
=
torch
.
device
(
"cuda"
)
vllm_config
=
create_vllm_config
(
model_name
=
"meta-llama/Meta-Llama-3-8B"
,
block_size
=
16
,
max_model_len
=
1024
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
# Use a mixed batch that will create groups spanning multiple sequences
batch_spec
=
BatchSpec
(
seq_lens
=
[
35
,
64
,
128
,
256
],
query_lens
=
[
33
,
5
,
32
,
64
],
name
=
"test_mixed_batch"
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
vllm_config
.
cache_config
.
block_size
,
device
)
builder
=
FlexAttentionMetadataBuilder
(
kv_cache_spec
,
[],
vllm_config
,
device
)
metadata_direct
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
builder
.
direct_build
=
False
metadata_slow
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
assert
metadata_direct
.
block_mask
is
not
None
assert
metadata_slow
.
block_mask
is
not
None
# Extract block indices for comparison, B, H are the same
direct_indices
=
metadata_direct
.
block_mask
.
kv_indices
[
0
,
0
]
slow_indices
=
metadata_slow
.
block_mask
.
kv_indices
[
0
,
0
]
direct_num
=
metadata_direct
.
block_mask
.
kv_num_blocks
[
0
,
0
]
slow_num
=
metadata_slow
.
block_mask
.
kv_num_blocks
[
0
,
0
]
# main test: every block needed by slow path must be in direct path
num_groups
=
direct_num
.
shape
[
0
]
all_contained
=
True
missing_details
=
[]
for
group_idx
in
range
(
num_groups
):
direct_blocks
=
set
(
direct_indices
[
group_idx
,
:
direct_num
[
group_idx
]].
tolist
())
slow_blocks
=
set
(
slow_indices
[
group_idx
,
:
slow_num
[
group_idx
]].
tolist
())
missing_blocks
=
slow_blocks
-
direct_blocks
if
missing_blocks
:
all_contained
=
False
missing_details
.
append
(
f
"Group
{
group_idx
}
: missing
{
sorted
(
missing_blocks
)
}
"
)
assert
all_contained
,
(
"Direct path is missing blocks required by slow path:
\n
"
+
"
\n
"
.
join
(
missing_details
))
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
tests/v1/attention/test_attention_backends.py
View file @
e0329ed4
...
...
@@ -10,14 +10,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
create_standard_kv_cache_spec
,
create_vllm_config
,
get_attention_backend
)
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
,
is_torch_equal_or_newer
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
set_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
BACKENDS_TO_TEST
=
[
_Backend
.
FLASH_ATTN_VLLM_V1
,
_Backend
.
FLASHINFER_VLLM_V1
,
_Backend
.
FLEX_ATTENTION
,
_Backend
.
TRITON_ATTN_VLLM_V1
,
_Backend
.
TREE_ATTN
_Backend
.
FLEX_ATTENTION
,
_Backend
.
TRITON_ATTN_VLLM_V1
,
_Backend
.
TREE_ATTN
,
"FLEX_ATTENTION_SLOW"
]
# Remove flashinfer from the list if it's not available
...
...
@@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata
:
CommonAttentionMetadata
,
randomize_blocks
:
bool
=
True
)
->
torch
.
Tensor
:
"""Create and prepopulate a KV cache with context data.
Args:
k_contexts: List of key context tensors for each sequence
v_contexts: List of value context tensors for each sequence
...
...
@@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache(
device: Device to create the cache on
num_blocks: Total number of blocks in the cache
block_table: Block table tensor to populate
randomize_blocks: Whether to randomly permute blocks
randomize_blocks: Whether to randomly permute blocks
or use sequential order
Returns:
Tuple of (kv_cache, updated_block_table)
"""
...
...
@@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
kv_cache
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Run attention computation using the specified backend's AttentionImpl."""
builder_cls
,
impl_cls
=
get_attention_backend
(
backend
)
# Handle special case for FLEX_ATTENTION_SLOW
actual_backend
=
backend
use_direct_block_mask
=
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
if
backend
==
"FLEX_ATTENTION_SLOW"
:
actual_backend
=
_Backend
.
FLEX_ATTENTION
use_direct_block_mask
=
False
builder_cls
,
impl_cls
=
get_attention_backend
(
actual_backend
)
# Mock flashinfer's get_per_layer_parameters if needed
if
backend
==
_Backend
.
FLASHINFER_VLLM_V1
:
if
actual_
backend
==
_Backend
.
FLASHINFER_VLLM_V1
:
import
unittest.mock
from
vllm.v1.attention.backends.utils
import
PerLayerParameters
...
...
@@ -239,6 +248,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
else
:
# Build metadata
builder
=
builder_cls
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
if
actual_backend
==
_Backend
.
FLEX_ATTENTION
:
builder
.
direct_build
=
use_direct_block_mask
attn_metadata
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
...
...
@@ -453,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
rtol
=
1e-2
atol
=
5e-3
if
backend_name
==
_Backend
.
FLEX_ATTENTION
:
atol
=
5e-1
# TODO: figure out why flex_attention has such large
# numerical differences for medium_decode, medium_prefill,
# mixed_medium
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)).
item
()
max_rel_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)
/
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
e0329ed4
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment