Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e0329ed4
Unverified
Commit
e0329ed4
authored
Aug 25, 2025
by
Driss Guessous
Committed by
GitHub
Aug 25, 2025
Browse files
Updates to Flex + VLLm integration (#21416)
Signed-off-by:
drisspg
<
drisspguessous@gmail.com
>
parent
6879cd80
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
439 additions
and
103 deletions
+439
-103
tests/kernels/test_flex_attention.py
tests/kernels/test_flex_attention.py
+87
-23
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+18
-12
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+334
-68
No files found.
tests/kernels/test_flex_attention.py
View file @
e0329ed4
...
...
@@ -9,12 +9,17 @@ import pytest
import
torch
from
packaging
import
version
from
vllm
import
SamplingParams
from
tests.v1.attention.utils
import
(
BatchSpec
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
create_vllm_config
)
from
vllm.v1.attention.backends.flex_attention
import
(
FlexAttentionMetadataBuilder
)
from
..models.utils
import
check_embeddings_close
from
..models.utils
import
check_embeddings_close
,
check_logprobs_close
TORCH_VERSION
=
version
.
parse
(
torch
.
__version__
)
MINIMUM_TORCH_VERSION
=
version
.
parse
(
"2.7.0"
)
DIRECT_BUILD_VERSION
=
version
.
parse
(
"2.9.dev0"
)
def
set_seed
(
seed
):
...
...
@@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are
identical
when using the same seed.
the default backend, ensuring they are
similar
when using the same seed.
"""
model_name
=
"Qwen/Qwen2.5-1.5B-Instruct"
seed
=
42
max_tokens
=
24
num_logprobs
=
5
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
,
seed
=
seed
,
max_tokens
=
max_tokens
)
# Run with flex attention
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
...
...
@@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
)
as
llm_flex
:
output_flex
=
llm_flex
.
generate
(
prompts
,
sampling_params
)
output_flex
=
llm_flex
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
)
# Run with default backend
with
monkeypatch
.
context
()
as
m
:
...
...
@@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
runner
=
"generate"
,
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
)
as
llm_default
:
output_default
=
llm_default
.
generate
(
prompts
,
sampling_params
)
# Compare outputs from both backends
for
i
,
(
flex_result
,
default_result
)
in
enumerate
(
zip
(
output_flex
,
output_default
)):
prompt
=
prompts
[
i
]
flex_text
=
flex_result
[
1
][
0
]
default_text
=
default_result
[
1
][
0
]
assert
flex_text
==
default_text
,
(
f
"FlexAttention output doesn't match default for:
{
prompt
!
r
}
\n
"
f
"FlexAttention:
{
flex_text
!
r
}
\n
"
f
"Default:
{
default_text
!
r
}
"
)
enforce_eager
=
True
,
gpu_memory_utilization
=
0.85
)
as
llm_default
:
output_default
=
llm_default
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
output_flex
,
outputs_1_lst
=
output_default
,
name_0
=
"flex"
,
name_1
=
"default"
,
)
@
pytest
.
mark
.
skipif
(
...
...
@@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
DIRECT_BUILD_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
)
def
test_block_mask_direct_vs_slow_path
():
"""Test that direct path block mask is a superset of slow path.
The direct path may include extra blocks for performance (over-estimation),
but must include all blocks that the slow path determines are necessary.
"""
device
=
torch
.
device
(
"cuda"
)
vllm_config
=
create_vllm_config
(
model_name
=
"meta-llama/Meta-Llama-3-8B"
,
block_size
=
16
,
max_model_len
=
1024
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
# Use a mixed batch that will create groups spanning multiple sequences
batch_spec
=
BatchSpec
(
seq_lens
=
[
35
,
64
,
128
,
256
],
query_lens
=
[
33
,
5
,
32
,
64
],
name
=
"test_mixed_batch"
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
vllm_config
.
cache_config
.
block_size
,
device
)
builder
=
FlexAttentionMetadataBuilder
(
kv_cache_spec
,
[],
vllm_config
,
device
)
metadata_direct
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
builder
.
direct_build
=
False
metadata_slow
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
assert
metadata_direct
.
block_mask
is
not
None
assert
metadata_slow
.
block_mask
is
not
None
# Extract block indices for comparison, B, H are the same
direct_indices
=
metadata_direct
.
block_mask
.
kv_indices
[
0
,
0
]
slow_indices
=
metadata_slow
.
block_mask
.
kv_indices
[
0
,
0
]
direct_num
=
metadata_direct
.
block_mask
.
kv_num_blocks
[
0
,
0
]
slow_num
=
metadata_slow
.
block_mask
.
kv_num_blocks
[
0
,
0
]
# main test: every block needed by slow path must be in direct path
num_groups
=
direct_num
.
shape
[
0
]
all_contained
=
True
missing_details
=
[]
for
group_idx
in
range
(
num_groups
):
direct_blocks
=
set
(
direct_indices
[
group_idx
,
:
direct_num
[
group_idx
]].
tolist
())
slow_blocks
=
set
(
slow_indices
[
group_idx
,
:
slow_num
[
group_idx
]].
tolist
())
missing_blocks
=
slow_blocks
-
direct_blocks
if
missing_blocks
:
all_contained
=
False
missing_details
.
append
(
f
"Group
{
group_idx
}
: missing
{
sorted
(
missing_blocks
)
}
"
)
assert
all_contained
,
(
"Direct path is missing blocks required by slow path:
\n
"
+
"
\n
"
.
join
(
missing_details
))
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
tests/v1/attention/test_attention_backends.py
View file @
e0329ed4
...
...
@@ -10,14 +10,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
create_standard_kv_cache_spec
,
create_vllm_config
,
get_attention_backend
)
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
,
is_torch_equal_or_newer
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
set_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
BACKENDS_TO_TEST
=
[
_Backend
.
FLASH_ATTN_VLLM_V1
,
_Backend
.
FLASHINFER_VLLM_V1
,
_Backend
.
FLEX_ATTENTION
,
_Backend
.
TRITON_ATTN_VLLM_V1
,
_Backend
.
TREE_ATTN
_Backend
.
FLEX_ATTENTION
,
_Backend
.
TRITON_ATTN_VLLM_V1
,
_Backend
.
TREE_ATTN
,
"FLEX_ATTENTION_SLOW"
]
# Remove flashinfer from the list if it's not available
...
...
@@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata
:
CommonAttentionMetadata
,
randomize_blocks
:
bool
=
True
)
->
torch
.
Tensor
:
"""Create and prepopulate a KV cache with context data.
Args:
k_contexts: List of key context tensors for each sequence
v_contexts: List of value context tensors for each sequence
...
...
@@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache(
device: Device to create the cache on
num_blocks: Total number of blocks in the cache
block_table: Block table tensor to populate
randomize_blocks: Whether to randomly permute blocks
randomize_blocks: Whether to randomly permute blocks
or use sequential order
Returns:
Tuple of (kv_cache, updated_block_table)
"""
...
...
@@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
kv_cache
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Run attention computation using the specified backend's AttentionImpl."""
builder_cls
,
impl_cls
=
get_attention_backend
(
backend
)
# Handle special case for FLEX_ATTENTION_SLOW
actual_backend
=
backend
use_direct_block_mask
=
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
if
backend
==
"FLEX_ATTENTION_SLOW"
:
actual_backend
=
_Backend
.
FLEX_ATTENTION
use_direct_block_mask
=
False
builder_cls
,
impl_cls
=
get_attention_backend
(
actual_backend
)
# Mock flashinfer's get_per_layer_parameters if needed
if
backend
==
_Backend
.
FLASHINFER_VLLM_V1
:
if
actual_
backend
==
_Backend
.
FLASHINFER_VLLM_V1
:
import
unittest.mock
from
vllm.v1.attention.backends.utils
import
PerLayerParameters
...
...
@@ -239,6 +248,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
else
:
# Build metadata
builder
=
builder_cls
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
if
actual_backend
==
_Backend
.
FLEX_ATTENTION
:
builder
.
direct_build
=
use_direct_block_mask
attn_metadata
=
builder
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
,
...
...
@@ -453,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
rtol
=
1e-2
atol
=
5e-3
if
backend_name
==
_Backend
.
FLEX_ATTENTION
:
atol
=
5e-1
# TODO: figure out why flex_attention has such large
# numerical differences for medium_decode, medium_prefill,
# mixed_medium
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)).
item
()
max_rel_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)
/
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
e0329ed4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with Fl
ash
Attention."""
from
collections
import
defaultdict
"""Attention layer with Fl
ex
Attention."""
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
import
torch._dynamo.decorators
import
torch.nn.functional
as
F
from
torch.nn.attention.flex_attention
import
(
BlockMask
,
_mask_mod_signature
,
_score_mod_signature
,
create_block_mask
,
...
...
@@ -16,13 +18,17 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
is_quantized_kv_cache
)
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.
platform
s
import
c
urrent_platform
from
vllm.
util
s
import
c
div
,
is_torch_equal_or_newer
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
create_block_mask_compiled
=
torch
.
compile
(
create_block_mask
,
fullgraph
=
True
,
mode
=
"reduce-overhead"
)
...
...
@@ -36,6 +42,23 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
torch
.
arange
(
len
(
counts
),
device
=
device
,
dtype
=
torch
.
int32
),
counts
)
def
pad_to_multiple
(
x
:
torch
.
Tensor
,
multiple
:
int
,
dim
:
int
):
difference
=
(
multiple
-
(
x
.
shape
[
dim
]
%
multiple
))
%
multiple
if
difference
==
0
:
return
x
dim
=
dim
if
dim
>=
0
else
x
.
ndim
+
dim
pad_list
=
[]
for
i
in
range
(
x
.
ndim
-
1
,
dim
-
1
,
-
1
):
if
i
==
dim
:
pad_list
.
extend
([
0
,
difference
])
else
:
pad_list
.
extend
([
0
,
0
])
return
F
.
pad
(
x
,
pad_list
,
mode
=
"constant"
,
value
=
0
)
class
FlexAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
...
...
@@ -77,10 +100,10 @@ class FlexAttentionBackend(AttentionBackend):
return
False
#
@torch.compile(fullgraph=True, mode="reduce-overhead")
def
physical_to_logical_mapping
(
block_table
:
torch
.
Tensor
,
total_blocks
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
#@torch.compile(fullgraph=True, mode="reduce-overhead")
def
physical_to_logical_mapping
(
block_table
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
total_blocks
:
int
)
->
torch
.
Tensor
:
"""
Creates an inverse mapping from physical block locations to logical indices.
...
...
@@ -114,13 +137,38 @@ def physical_to_logical_mapping(
If a physical block is not mapped to by any logical block,
its value in the result will be -1.
IMPORTANT: Garbage Value Protection
────────────────────────────────────
The block_table tensor may contain garbage values in unused positions
(beyond the actual sequence length). For example, if a sequence only
needs 3 blocks but the table has space for 8:
block_table[0] = [10, 25, 7, 999, 1234, 888, ...]
^^^^^^^^^^^^^^^^^^^^
garbage values
These garbage values can cause issues because:
1. They may map to valid physical blocks by coincidence
2. The scatter_ operation will assign them logical indices
3. Later attention computations may incorrectly access these blocks
To prevent this, we use seq_lens and block_size to mask out unused
entries, ensuring only valid block references are processed.
Args:
block_table: Tensor of shape [max_reqs, max_num_blocks]
mapping logical blocks to physical locations
mapping logical blocks to physical locations. May contain
garbage values in unused positions.
seq_lens: Tensor of sequence lengths for each request. Used to
determine how many blocks are actually needed per sequence.
block_size: Size of each block in tokens. Used with seq_lens to
compute the number of valid blocks per sequence.
total_blocks: Total number of physical blocks available
Returns:
A tensor of shape [max_reqs, max_physical_block]
A tensor of shape [max_reqs, total_blocks] where each entry
physical_to_logical[req_id, physical_block] contains the logical
block index for that physical block, or -1 if unused.
"""
max_reqs
,
max_num_blocks
=
block_table
.
shape
device
=
block_table
.
device
...
...
@@ -130,17 +178,76 @@ def physical_to_logical_mapping(
dtype
=
torch
.
long
,
device
=
device
)
logical_indices
=
(
torch
.
arange
(
max_num_blocks
,
device
=
device
).
unsqueeze
(
0
).
expand
(
max_reqs
,
-
1
))
# Only process valid blocks to avoid garbage values
num_blocks_per_seq
=
cdiv
(
seq_lens
,
block_size
)
mask
=
torch
.
arange
(
max_num_blocks
,
device
=
device
)[
None
,
:]
<
num_blocks_per_seq
[:,
None
]
physical_to_logical
.
scatter_
(
-
1
,
block_table
.
to
(
torch
.
int64
),
logical_indices
)
# TODO Confirm - Seems like block 0 is always empty so we reset it manually
valid_block_table
=
torch
.
where
(
mask
,
block_table
,
0
)
valid_logical_indices
=
torch
.
where
(
mask
,
torch
.
arange
(
max_num_blocks
,
device
=
device
)[
None
,
:],
0
)
physical_to_logical
.
scatter_
(
-
1
,
valid_block_table
.
to
(
torch
.
int64
),
valid_logical_indices
)
# NB - Seems like block 0 is always empty so we reset it manually
physical_to_logical
[:,
0
]
=
-
1
return
physical_to_logical
def
unique_static_unsorted
(
x
:
torch
.
Tensor
,
*
,
M
:
int
,
# maximum positive value (0 is “skip me”)
dim
:
int
=
-
1
,
# axis along which to deduplicate
ignored_val
:
int
=
0
,
# value to ignore
pad_val
:
int
=
-
1
,
# sentinel for unused slots
)
->
torch
.
Tensor
:
"""
- Keeps the first occurrence of each non-zero value while preserving order,
then left-packs those uniques and fills the rest with `pad_val`.
- Returns (packed, keep_mask) with the *same shape* as `x`.
- Requires that all values be in the range [0, M]
- Skips ignored_val
Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory.
Example:
x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1]
"""
if
not
(
-
1
<=
pad_val
<=
M
):
raise
ValueError
(
"`pad_val` must lie in [-1, M]"
)
# ── move `dim` to the end so we can treat tensor as [B, N] ──────────
dim
=
dim
%
x
.
ndim
x_perm
=
x
.
movedim
(
dim
,
-
1
)
# shape [..., N]
B
,
N
=
x_perm
.
numel
()
//
x_perm
.
shape
[
-
1
],
x_perm
.
shape
[
-
1
]
x_flat
=
x_perm
.
reshape
(
B
,
N
)
# [B, N]
device
=
x
.
device
idx
=
torch
.
arange
(
N
,
device
=
device
).
expand
(
B
,
N
)
# per-row indices
# ── build first-occurrence table for every v ∈ [0, M] ───────────────
first_idx
=
torch
.
full
((
B
,
M
+
1
),
N
,
device
=
device
)
# “∞”
# scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i
first_idx
.
scatter_reduce_
(
1
,
x_flat
,
idx
,
reduce
=
"amin"
)
# ── keep mask: first occurrence *and* value ≠ 0 ─────────────────────
keep
=
(
x_flat
!=
ignored_val
)
&
(
idx
==
first_idx
.
gather
(
1
,
x_flat
)
)
# [B, N]
# ── left-pack uniques into a fresh tensor ───────────────────────────
dest_pos
=
torch
.
cumsum
(
keep
.
to
(
torch
.
long
),
dim
=
1
)
-
1
# where to go
packed_flat
=
torch
.
full_like
(
x_flat
,
pad_val
)
rows
,
src_cols
=
torch
.
nonzero
(
keep
,
as_tuple
=
True
)
packed_flat
[
rows
,
dest_pos
[
rows
,
src_cols
]]
=
x_flat
[
rows
,
src_cols
]
# ── restore original layout ─────────────────────────────────────────
packed
=
packed_flat
.
reshape
(
x_perm
.
shape
).
movedim
(
-
1
,
dim
)
return
packed
def
causal_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
kv_idx
:
torch
.
Tensor
):
return
q_idx
>=
kv_idx
...
...
@@ -170,6 +277,7 @@ class FlexAttentionMetadata:
num_reqs
:
int
physical_to_logical
:
torch
.
Tensor
decode_offset
:
torch
.
Tensor
num_blocks_per_seq
:
torch
.
Tensor
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
...
...
@@ -179,6 +287,46 @@ class FlexAttentionMetadata:
block_mask
:
Optional
[
BlockMask
]
=
None
score_mod
:
Optional
[
_score_mod_signature
]
=
None
logical_mask_mod
:
_mask_mod_signature
=
causal_mask_mod
doc_ids
:
Optional
[
torch
.
Tensor
]
=
None
direct_build
:
bool
=
True
q_block_size
:
int
=
16
kv_block_size
:
int
=
16
transformed_score_mod
:
Optional
[
_score_mod_signature
]
=
None
def
_convert_physical_to_logical
(
self
,
request_lookup
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
physical_kv_idx
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Convert physical indices to logical indices for both query and kv.
NB is_within_lower_bound: do sequences start on block_boundaries?
Returns:
tuple of (is_valid, logical_q_idx, logical_kv_idx)
"""
# Map query indices to corresponding request indices
q_req
=
request_lookup
[
q_idx
]
# Convert physical KV indices to logical indices
physical_kv_block
=
physical_kv_idx
//
self
.
block_size
physical_kv_offset
=
physical_kv_idx
%
self
.
block_size
logical_block_idx
=
self
.
physical_to_logical
[
q_req
,
physical_kv_block
]
logical_kv_idx
=
(
logical_block_idx
*
self
.
block_size
+
physical_kv_offset
)
# Determine valid kv indices
live_block
=
logical_block_idx
>=
0
within_upper_bound
=
logical_kv_idx
<
self
.
seq_lens
[
q_req
]
within_lower_bound
=
logical_kv_idx
>=
0
is_valid
=
live_block
&
within_upper_bound
&
within_lower_bound
# Convert physical query indices to logical indices
local_q_idx
=
q_idx
-
self
.
query_start_loc
[
q_req
]
logical_q_idx
=
local_q_idx
+
self
.
decode_offset
[
q_req
]
return
is_valid
,
logical_q_idx
,
logical_kv_idx
def
get_causal_mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the mask_mod function for FlexAttention.
...
...
@@ -191,11 +339,8 @@ class FlexAttentionMetadata:
With this info we create the "logical" indices that are passed to
mask_mod functions. This allows mask mod functions to be agnostic to
layout of the query and key/value tensors.
TODO is_within_lower_bound: do sequences start on block_boundaries?
"""
# Create a lookup mapping from query indices -> request number
request_lookup
=
_offsets_to_doc_ids_tensor
(
self
.
query_start_loc
)
assert
self
.
doc_ids
is
not
None
def
final_mask_mod
(
b
:
torch
.
Tensor
,
...
...
@@ -203,27 +348,9 @@ class FlexAttentionMetadata:
q_idx
:
torch
.
Tensor
,
physical_kv_idx
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Map query indices to corresponding request indices
q_req
=
request_lookup
[
q_idx
]
# Convert physical KV indices to logical indices
physical_kv_block
=
physical_kv_idx
//
self
.
block_size
physical_kv_offset
=
physical_kv_idx
%
self
.
block_size
logical_block_idx
=
self
.
physical_to_logical
[
q_req
,
physical_kv_block
]
logical_kv_idx
=
logical_block_idx
*
self
.
block_size
+
physical_kv_offset
# noqa: E501
# Determine valid kv indices
live_block
=
logical_block_idx
>=
0
within_upper_bound
=
logical_kv_idx
<
self
.
seq_lens
[
q_req
]
within_lower_bound
=
logical_kv_idx
>=
0
is_valid
=
live_block
&
within_upper_bound
&
within_lower_bound
# Convert physical query indices to logical indices
local_q_idx
=
q_idx
-
self
.
query_start_loc
[
q_req
]
logical_q_idx
=
local_q_idx
+
self
.
decode_offset
[
q_req
]
(
is_valid
,
logical_q_idx
,
logical_kv_idx
)
=
self
.
_convert_physical_to_logical
(
self
.
doc_ids
,
q_idx
,
physical_kv_idx
)
# Apply mask modification only for valid indices
return
torch
.
where
(
is_valid
,
...
...
@@ -236,7 +363,7 @@ class FlexAttentionMetadata:
def
get_bidirectional_mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the encoder mask_mod function for FlexAttention.
Since the encoder bidirectional attention doesn't run with
Since the encoder bidirectional attention doesn't run with
KV cache, this function creates a mask based on the
packed query sequences.
"""
...
...
@@ -253,6 +380,97 @@ class FlexAttentionMetadata:
return
final_mask_mod
def
get_transformed_score_mod
(
self
)
->
Optional
[
_score_mod_signature
]:
"""Creates the transformed score_mod function for FlexAttention.
This function wraps the user's score_mod to handle physical-to-logical
index conversion, similar to how get_mask_mod works for mask functions.
"""
if
self
.
score_mod
is
None
:
return
None
# Create a lookup mapping from query indices -> request number
request_lookup
=
_offsets_to_doc_ids_tensor
(
self
.
query_start_loc
)
user_score_mod
=
self
.
score_mod
def
transformed_score_mod
(
score
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
physical_kv_idx
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
(
is_valid
,
logical_q_idx
,
logical_kv_idx
)
=
self
.
_convert_physical_to_logical
(
request_lookup
,
q_idx
,
physical_kv_idx
)
return
torch
.
where
(
is_valid
,
user_score_mod
(
score
,
b
,
h
,
logical_q_idx
,
logical_kv_idx
,
physical_q
=
q_idx
),
-
float
(
'inf'
))
return
transformed_score_mod
def
_build_block_mask_direct
(
self
)
->
BlockMask
:
"""Direct block mask construction for standard causal attention.
This method constructs the block mask directly using
BlockMask.from_kv_blocks which is much more efficient than the
generic create_block_mask approach.
The direct path works as follows:
1. For each query token, fetch blocks from block_table using max_seq_len
(this fetches more blocks than needed for shorter sequences)
2. Group query tokens into chunks of q_block_size
3. For each group, deduplicate the blocks using unique_static_unsorted
4. Create BlockMask using the deduplicated block indices
Over-estimation occurs when a group of q_block_size tokens contains
multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for
each sequence represented in the group, even though individual query
tokens may only need a subset of those blocks based on causal masking
and their position.
"""
page_to_block_ratio
=
self
.
kv_block_size
//
self
.
block_size
if
page_to_block_ratio
!=
1
:
raise
ValueError
(
f
"FlexAttention currently requires the cache block size "
f
"(
{
self
.
block_size
}
) to be equal to the kv_block_size "
f
"(
{
self
.
kv_block_size
}
). Please check your model's "
f
"configuration."
)
used_pages
=
self
.
block_table
[
self
.
doc_ids
,
:
cdiv
(
self
.
max_seq_len
,
self
.
block_size
)]
used_pages_padded
=
pad_to_multiple
(
used_pages
,
multiple
=
self
.
q_block_size
,
dim
=
0
)
used_pages_padded
=
used_pages_padded
.
reshape
(
used_pages_padded
.
shape
[
0
]
//
self
.
q_block_size
,
-
1
)
used_pages_padded
=
used_pages_padded
//
page_to_block_ratio
kv_indices
=
unique_static_unsorted
((
used_pages_padded
.
long
()),
M
=
self
.
num_blocks
).
to
(
torch
.
int32
)
kv_num_blocks
=
(
kv_indices
>=
0
).
sum
(
dim
=-
1
).
to
(
torch
.
int32
)
block_mask_kwargs
=
{
"seq_lengths"
:
(
self
.
num_actual_tokens
,
self
.
total_cache_tokens
),
"kv_num_blocks"
:
kv_num_blocks
[
None
,
None
],
"kv_indices"
:
kv_indices
[
None
,
None
],
"full_kv_num_blocks"
:
None
,
"full_kv_indices"
:
None
,
"BLOCK_SIZE"
:
(
self
.
q_block_size
,
self
.
kv_block_size
),
"mask_mod"
:
self
.
mask_mod
,
}
# compute_q_blocks parameter is available in PyTorch 2.9+
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
):
block_mask_kwargs
[
"compute_q_blocks"
]
=
False
return
BlockMask
.
from_kv_blocks
(
**
block_mask_kwargs
)
def
build_block_mask
(
self
)
->
BlockMask
:
if
self
.
causal
:
mask_mod
=
self
.
get_causal_mask_mod
()
...
...
@@ -267,6 +485,7 @@ class FlexAttentionMetadata:
self
.
num_actual_tokens
,
kv_len
,
device
=
self
.
block_table
.
device
,
BLOCK_SIZE
=
(
self
.
q_block_size
,
self
.
kv_block_size
),
)
def
__post_init__
(
self
):
...
...
@@ -275,8 +494,21 @@ class FlexAttentionMetadata:
assert
self
.
cu_prefix_query_lens
is
None
,
"Not implemented yet."
assert
self
.
prefix_kv_lens
is
None
,
"Not implemented yet."
assert
self
.
suffix_kv_lens
is
None
,
"Not implemented yet."
# Create a lookup mapping from query indices -> request number
self
.
doc_ids
=
_offsets_to_doc_ids_tensor
(
self
.
query_start_loc
)
self
.
num_blocks
=
self
.
total_cache_tokens
//
self
.
block_size
self
.
block_mask
=
self
.
build_block_mask
()
if
self
.
causal
:
self
.
mask_mod
=
self
.
get_causal_mask_mod
()
else
:
self
.
mask_mod
=
self
.
get_bidirectional_mask_mod
()
self
.
transformed_score_mod
=
self
.
get_transformed_score_mod
()
if
self
.
direct_build
and
self
.
causal
:
self
.
block_mask
=
self
.
_build_block_mask_direct
()
else
:
self
.
block_mask
=
self
.
build_block_mask
()
class
FlexAttentionMetadataBuilder
(
...
...
@@ -287,15 +519,24 @@ class FlexAttentionMetadataBuilder(
self
.
model_config
=
vllm_config
.
model_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
device
=
device
self
.
num_heads_q
=
self
.
model_config
.
get_num_attention_heads
(
vllm_config
.
parallel_config
)
self
.
parallel_config
)
self
.
num_heads_kv
=
self
.
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
)
self
.
parallel_config
)
self
.
headdim
=
self
.
model_config
.
get_head_size
()
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
kv_cache_spec
=
kv_cache_spec
self
.
device
=
device
self
.
direct_build
:
bool
=
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
self
.
q_block_size
:
int
=
16
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
else
128
self
.
kv_block_size
:
int
=
16
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
else
128
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
False
def
build
(
self
,
common_prefix_len
:
int
,
...
...
@@ -310,6 +551,7 @@ class FlexAttentionMetadataBuilder(
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
num_blocks_per_seq
=
cdiv
(
seq_lens
,
self
.
block_size
)
use_cascade
=
common_prefix_len
>
0
cu_prefix_query_lens
=
None
...
...
@@ -320,12 +562,15 @@ class FlexAttentionMetadataBuilder(
block_size
=
self
.
kv_cache_spec
.
block_size
max_possible_seq_len
=
self
.
model_config
.
max_model_len
total_cache_tokens
=
self
.
cache_config
.
num_gpu_blocks
*
block_size
num_gpu_blocks
=
self
.
cache_config
.
num_gpu_blocks
assert
num_gpu_blocks
is
not
None
,
\
"FlexAttention requires num_gpu_blocks to be set"
total_cache_tokens
=
(
num_gpu_blocks
*
block_size
)
inverse_block_table
=
physical_to_logical_mapping
(
block_table_tensor
,
se
lf
.
cache_config
.
num_gpu_blocks
)
block_table_tensor
,
se
q_lens
,
block_size
,
num_gpu_blocks
)
# Get the original offset tensor
offset_tensor
=
common_attn_metadata
.
num_computed_tokens_cpu
.
to
(
self
.
device
,
non_blocking
=
True
)
...
...
@@ -349,9 +594,16 @@ class FlexAttentionMetadataBuilder(
physical_to_logical
=
inverse_block_table
,
total_cache_tokens
=
total_cache_tokens
,
decode_offset
=
offset_tensor
,
num_blocks_per_seq
=
num_blocks_per_seq
,
direct_build
=
self
.
direct_build
,
q_block_size
=
self
.
q_block_size
,
kv_block_size
=
self
.
kv_block_size
,
)
return
out
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
False
class
FlexAttentionImpl
(
AttentionImpl
):
sliding_window
:
Optional
[
tuple
[
int
,
int
]]
...
...
@@ -370,6 +622,7 @@ class FlexAttentionImpl(AttentionImpl):
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
...
...
@@ -398,6 +651,7 @@ class FlexAttentionImpl(AttentionImpl):
raise
NotImplementedError
(
"FlexAttention does not support logits soft cap yet."
)
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
kv_sharing_target_layer_name
is
not
None
:
...
...
@@ -405,7 +659,6 @@ class FlexAttentionImpl(AttentionImpl):
"FlexAttention does not support kv sharing yet."
)
FlexAttentionBackend
.
validate_head_size
(
head_size
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlexAttention does not support quantized kv-cache. Yet"
)
...
...
@@ -493,35 +746,48 @@ class FlexAttentionImpl(AttentionImpl):
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
# default M=64, N=64 may run out of shared memory on some GPUs
# TODO: Explicit configs for each GPU?
# Not sure how to calculate the shared memory requirement
extra_kernel_options
=
defaultdict
[
str
,
int
](
lambda
:
64
)
if
query
.
dtype
==
torch
.
float32
:
extra_kernel_options
[
"BLOCK_M"
]
//=
2
extra_kernel_options
[
"BLOCK_N"
]
//=
2
if
current_platform
.
is_cuda
():
device_props
=
torch
.
cuda
.
get_device_properties
()
max_shared_memory
=
device_props
.
shared_memory_per_block_optin
if
max_shared_memory
<
144
*
1024
:
extra_kernel_options
[
"BLOCK_M"
]
//=
2
extra_kernel_options
[
"BLOCK_N"
]
//=
2
assert
attn_metadata
.
block_mask
is
not
None
block_m
,
block_n
=
attn_metadata
.
block_mask
.
BLOCK_SIZE
kernel_options
=
get_kernel_options
(
query
,
block_m
,
block_n
,
attn_metadata
.
direct_build
)
out
=
flex_attention_compiled
(
query
,
key_tensor
,
value_tensor
,
attn_metadata
.
score_mod
,
attn_metadata
.
transformed_
score_mod
,
attn_metadata
.
block_mask
,
self
.
scale
,
enable_gqa
=
enable_gqa
,
kernel_options
=
{
"FORCE_USE_FLEX_ATTENTION"
:
True
,
**
extra_kernel_options
},
kernel_options
=
kernel_options
,
)
# Flex doesn't have an out variant today, rely on epilogue fusion
out
=
out
.
permute
(
0
,
2
,
1
,
3
).
squeeze
(
0
)
output
[:
num_actual_tokens
,
:,
:].
copy_
(
out
)
return
output
def
get_kernel_options
(
query
,
block_m
,
block_n
,
use_direct_build
:
bool
)
->
dict
[
str
,
Union
[
int
,
bool
]]:
kernel_options
:
dict
[
str
,
Union
[
int
,
bool
]]
=
{
"FORCE_USE_FLEX_ATTENTION"
:
True
,
}
if
use_direct_build
:
kernel_options
[
"BLOCK_M"
]
=
block_m
kernel_options
[
"BLOCK_N"
]
=
block_n
return
kernel_options
else
:
kernel_options
[
"BLOCK_M"
]
=
64
kernel_options
[
"BLOCK_N"
]
=
64
if
query
.
dtype
==
torch
.
float32
:
kernel_options
[
"BLOCK_M"
]
=
32
kernel_options
[
"BLOCK_N"
]
=
32
# if current_platform.is_cuda():
if
torch
.
cuda
.
is_available
():
device_props
=
torch
.
cuda
.
get_device_properties
()
max_shared_memory
=
device_props
.
shared_memory_per_block_optin
if
max_shared_memory
<
144
*
1024
:
kernel_options
[
"BLOCK_M"
]
=
32
kernel_options
[
"BLOCK_N"
]
=
32
return
kernel_options
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment