Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
61b8cea3
Unverified
Commit
61b8cea3
authored
Jul 24, 2025
by
Lucas Wilkinson
Committed by
GitHub
Jul 24, 2025
Browse files
[Attention] Optimize FlashInfer MetadataBuilder Build call (#21137)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
526078a9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
94 additions
and
78 deletions
+94
-78
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+10
-3
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+1
-1
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+83
-74
No files found.
tests/v1/attention/test_attention_backends.py
View file @
61b8cea3
...
...
@@ -11,7 +11,8 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
create_vllm_config
,
get_attention_backend
)
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
set_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
BACKENDS_TO_TEST
=
[
...
...
@@ -212,7 +213,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
from
vllm.v1.attention.backends.flashinfer
import
PerLayerParameters
def
mock_get_per_layer_parameters
(
vllm_config
):
def
mock_get_per_layer_parameters
(
vllm_config
,
impl_cls
):
# Return mock parameters for a single layer
head_size
=
vllm_config
.
model_config
.
get_head_size
()
return
{
...
...
@@ -297,7 +298,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
vllm_config
=
create_vllm_config
(
model_name
=
model
)
vllm_config
=
create_vllm_config
(
model_name
=
model
,
max_model_len
=
max
(
batch_spec
.
seq_lens
))
device
=
torch
.
device
(
"cuda:0"
)
kv_cache_spec
=
create_standard_kv_cache_spec
(
vllm_config
)
...
...
@@ -419,6 +421,11 @@ def test_backend_correctness(batch_spec_name: str, model: str):
if
backend_name
==
_Backend
.
FLASHINFER_VLLM_V1
:
kv_cache_for_backend
=
kv_cache
.
transpose
(
0
,
1
)
# For FlashInfer default to HND layout and
kv_cache_for_backend
=
kv_cache_for_backend
.
transpose
(
2
,
3
).
contiguous
().
transpose
(
2
,
3
)
set_kv_cache_layout
(
"HND"
)
backend_output
=
run_attention_backend
(
backend_name
,
kv_cache_spec
,
vllm_config
,
device
,
common_attn_metadata
,
...
...
tests/v1/attention/utils.py
View file @
61b8cea3
...
...
@@ -66,7 +66,7 @@ def create_common_attn_metadata(
num_computed_tokens_cpu
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
)
# Create block table (random for testing)
max_blocks
=
max
(
batch_spec
.
seq_lens
)
//
block_size
+
1
max_blocks
=
(
max
(
batch_spec
.
seq_lens
)
+
block_size
-
1
)
//
block_size
block_table_tensor
=
torch
.
randint
(
0
,
max_block_idx
,
(
batch_spec
.
batch_size
,
max_blocks
),
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
61b8cea3
...
...
@@ -18,6 +18,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
PerLayerParameters
,
...
...
@@ -158,7 +159,7 @@ class FlashInferMetadata:
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
qo_indptr
:
torch
.
Tensor
qo_indptr
_cpu
:
torch
.
Tensor
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
...
...
@@ -167,13 +168,13 @@ class FlashInferMetadata:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr
:
torch
.
Tensor
# The page indices of the paged kv cache
# The indptr of the paged kv cache, shape: [batch_size + 1]
(CPU for plan)
paged_kv_indptr
_cpu
:
torch
.
Tensor
# The page indices of the paged kv cache
(on device for plan)
paged_kv_indices
:
torch
.
Tensor
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len
:
torch
.
Tensor
# the paged kv cache, shape: [batch_size]
(CPU for plan)
paged_kv_last_page_len
_cpu
:
torch
.
Tensor
# The number of query/output heads
num_qo_heads
:
int
# The number of key/value heads
...
...
@@ -201,22 +202,17 @@ class FlashInferMetadata:
num_prefills
:
int
num_prefill_tokens
:
int
# For cascade attention.
# For cascade attention
(CPU for planning)
.
use_cascade
:
bool
shared_qo_indptr
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_page_indptr
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_page_indices
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_last_page_len
:
Optional
[
torch
.
Tensor
]
=
None
shared_qo_indptr
_cpu
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_page_indptr
_cpu
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_page_indices
_cpu
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_last_page_len
_cpu
:
Optional
[
torch
.
Tensor
]
=
None
prefill_wrapper
:
Optional
[
BatchPrefillWithPagedKVCacheWrapper
]
=
None
decode_wrapper
:
Optional
[
BatchDecodeWithPagedKVCacheWrapper
]
=
None
cascade_wrapper
:
Optional
[
MultiLevelCascadeAttentionWrapper
]
=
None
@
property
def
query_start_loc
(
self
):
# The GPUModelRunner expects to be able to access this property.
return
self
.
qo_indptr
def
__post_init__
(
self
):
if
self
.
head_dim
is
not
None
:
FlashInferBackend
.
validate_head_size
(
self
.
head_dim
)
...
...
@@ -238,6 +234,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
vllm_config
=
vllm_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
kv_cache_spec
=
kv_cache_spec
max_num_blocks_per_request
=
cdiv
(
vllm_config
.
model_config
.
max_model_len
,
self
.
kv_cache_spec
.
block_size
)
self
.
block_table_arange
=
torch
.
arange
(
max_num_blocks_per_request
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
def
reorder_batch
(
self
,
input_batch
:
InputBatch
,
scheduler_output
:
SchedulerOutput
)
->
bool
:
...
...
@@ -285,21 +287,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
self
.
global_hyperparameters
is
None
:
self
.
global_hyperparameters
=
infer_global_hyperparameters
(
get_per_layer_parameters
(
self
.
vllm_config
,
FlashInferImpl
))
if
attn_metadata
.
use_cascade
:
attn_metadata
.
cascade_wrapper
=
self
.
_get_cascade_wrapper
()
attn_metadata
.
cascade_wrapper
.
plan
(
[
attn_metadata
.
shared_qo_indptr
,
attn_metadata
.
qo_indptr
],
[
attn_metadata
.
shared_kv_page_indptr
,
attn_metadata
.
paged_kv_indptr
attn_metadata
.
shared_qo_indptr_cpu
,
attn_metadata
.
qo_indptr_cpu
],
[
attn_metadata
.
shared_kv_page_indptr_cpu
,
attn_metadata
.
paged_kv_indptr_cpu
],
[
attn_metadata
.
shared_kv_page_indices
,
attn_metadata
.
shared_kv_page_indices
_cpu
,
attn_metadata
.
paged_kv_indices
],
[
attn_metadata
.
shared_kv_last_page_len
,
attn_metadata
.
paged_kv_last_page_len
attn_metadata
.
shared_kv_last_page_len
_cpu
,
attn_metadata
.
paged_kv_last_page_len
_cpu
],
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
...
...
@@ -320,22 +326,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Decodes are first so prefills start after the last decode
prefill_start
=
num_decodes
attn_metadata
.
prefill_wrapper
=
self
.
_get_prefill_wrapper
()
assert
attn_metadata
.
qo_indptr
[
prefill_start
:].
shape
[
assert
attn_metadata
.
qo_indptr
_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
attn_metadata
.
paged_kv_indptr
[
prefill_start
:].
shape
[
assert
attn_metadata
.
paged_kv_indptr
_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
attn_metadata
.
paged_kv_last_page_len
[
assert
attn_metadata
.
paged_kv_last_page_len
_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr
=
attn_metadata
.
qo_indptr
[
prefill_start
:]
-
attn_metadata
.
qo_indptr
[
prefill_start
]
qo_indptr
_cpu
=
attn_metadata
.
qo_indptr
_cpu
[
prefill_start
:]
-
attn_metadata
.
qo_indptr
_cpu
[
prefill_start
]
attn_metadata
.
prefill_wrapper
.
plan
(
qo_indptr
,
attn_metadata
.
paged_kv_indptr
[
prefill_start
:],
qo_indptr
_cpu
,
attn_metadata
.
paged_kv_indptr
_cpu
[
prefill_start
:],
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_len
[
prefill_start
:],
attn_metadata
.
paged_kv_last_page_len
_cpu
[
prefill_start
:],
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
attn_metadata
.
head_dim
,
...
...
@@ -357,9 +363,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
attn_metadata
.
head_dim
):
attn_metadata
.
decode_wrapper
.
plan
(
attn_metadata
.
paged_kv_indptr
[:
num_decodes
+
1
],
attn_metadata
.
paged_kv_indptr
_cpu
[:
num_decodes
+
1
],
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_len
[:
num_decodes
],
attn_metadata
.
paged_kv_last_page_len
_cpu
[:
num_decodes
],
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
attn_metadata
.
head_dim
,
...
...
@@ -383,55 +389,58 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
split_decodes_and_prefills
(
common_attn_metadata
)
page_size
=
self
.
kv_cache_spec
.
block_size
device
=
self
.
device
qo_indptr
=
common_attn_metadata
.
query_start_loc
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
()
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_bounds
=
(
seq_lens
+
page_size
-
1
)
//
page_size
block_table_bounds
_cpu
=
(
seq_lens
_cpu
+
page_size
-
1
)
//
page_size
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
# Grab the blocks of the shared prefix from the first request.
assert
common_prefix_len
%
page_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
page_size
shared_qo_indptr
=
torch
.
tensor
([
0
,
num_actual_tokens
],
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
device
)
shared_kv_page_indptr
=
torch
.
tensor
([
0
,
num_common_kv_blocks
],
device
=
'cpu'
)
shared_kv_page_indptr
_cpu
=
torch
.
tensor
([
0
,
num_common_kv_blocks
],
dtype
=
torch
.
int32
,
device
=
device
)
shared_kv_page_indices
=
block_table_tensor
[
device
=
'cpu'
)
shared_kv_page_indices
_cpu
=
block_table_tensor
[
0
,
:
num_common_kv_blocks
]
shared_kv_last_page_len
=
torch
.
tensor
([
page_size
],
shared_kv_last_page_len
_cpu
=
torch
.
tensor
([
page_size
],
dtype
=
torch
.
int32
,
device
=
device
)
device
=
'cpu'
)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor
=
block_table_tensor
[:,
num_common_kv_blocks
:]
block_table_bounds
-=
num_common_kv_blocks
block_table_bounds
_cpu
-=
num_common_kv_blocks
else
:
shared_qo_indptr
=
None
shared_kv_page_indptr
=
None
shared_kv_page_indices
=
None
shared_kv_last_page_len
=
None
mask
=
(
torch
.
arange
(
block_table_tensor
.
size
(
1
),
dtype
=
block_table_tensor
.
dtype
,
device
=
block_table_tensor
.
device
).
unsqueeze
(
0
)
shared_qo_indptr_cpu
=
None
shared_kv_page_indptr_cpu
=
None
shared_kv_page_indices_cpu
=
None
shared_kv_last_page_len_cpu
=
None
max_num_blocks
=
block_table_bounds_cpu
.
max
()
block_table_bounds
=
block_table_bounds_cpu
.
to
(
self
.
device
,
non_blocking
=
True
)
mask
=
(
self
.
block_table_arange
[:
max_num_blocks
].
unsqueeze
(
0
)
<
block_table_bounds
.
unsqueeze
(
1
))
paged_kv_indices
=
block_table_tensor
[
mask
]
paged_kv_indptr
=
torch
.
cat
([
torch
.
zeros
(
1
,
dtype
=
block_table_bounds
.
dtype
,
device
=
block_table_bounds
.
device
),
block_table_bounds
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
])
paged_kv_last_page_len
=
seq_lens
%
page_size
paged_kv_last_page_len
=
torch
.
where
(
paged_kv_last_page_len
==
0
,
page_size
,
paged_kv_last_page_len
)
paged_kv_indices
=
block_table_tensor
[
:,
:
max_num_blocks
][
mask
]
paged_kv_indptr
_cpu
=
torch
.
zeros
(
len
(
block_table_bounds_cpu
)
+
1
,
dtype
=
torch
.
int32
,
device
=
'cpu'
)
paged_kv_indptr_cpu
[
1
:]
=
block_table_bounds
_cpu
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
paged_kv_last_page_len_cpu
=
seq_lens_cpu
%
page_size
paged_kv_last_page_len
_cpu
=
torch
.
where
(
paged_kv_last_page_len
_cpu
==
0
,
page_size
,
paged_kv_last_page_len
_cpu
)
cache_dtype
=
self
.
cache_config
.
cache_dtype
if
cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
...
...
@@ -440,10 +449,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
attn_metadata
=
FlashInferMetadata
(
num_actual_tokens
=
num_actual_tokens
,
qo_indptr
=
qo_indptr
,
paged_kv_indptr
=
paged_kv_indptr
,
qo_indptr
_cpu
=
common_attn_metadata
.
query_start_loc_cpu
,
paged_kv_indptr
_cpu
=
paged_kv_indptr
_cpu
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len
=
paged_kv_last_page_len
,
paged_kv_last_page_len
_cpu
=
paged_kv_last_page_len
_cpu
,
num_qo_heads
=
self
.
vllm_config
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
),
num_kv_heads
=
self
.
kv_cache_spec
.
num_kv_heads
,
...
...
@@ -457,14 +466,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
use_cascade
=
use_cascade
,
shared_qo_indptr
=
shared_qo_indptr
,
shared_kv_page_indptr
=
shared_kv_page_indptr
,
shared_kv_page_indices
=
shared_kv_page_indices
,
shared_kv_last_page_len
=
shared_kv_last_page_len
,
shared_qo_indptr
_cpu
=
shared_qo_indptr
_cpu
,
shared_kv_page_indptr
_cpu
=
shared_kv_page_indptr
_cpu
,
shared_kv_page_indices
_cpu
=
shared_kv_page_indices
_cpu
,
shared_kv_last_page_len
_cpu
=
shared_kv_last_page_len
_cpu
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table_tensor
=
block_table_tensor
,
workspace_buffer
=
self
.
_workspace_buffer
,
workspace_buffer
=
self
.
_
get_
workspace_buffer
()
,
)
self
.
_plan
(
num_prefills
,
num_decodes
,
attn_metadata
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment