Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
efc88cf6
Unverified
Commit
efc88cf6
authored
Aug 25, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 25, 2025
Browse files
[Misc] Simplify FlashInfer attention metadata (#23585)
Signed-off-by:
Woosuk Kwon
<
woosuk@thinkingmachines.ai
>
parent
7b6a8372
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
114 additions
and
163 deletions
+114
-163
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+114
-163
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
efc88cf6
...
...
@@ -123,29 +123,9 @@ class FlashInferMetadata:
num_actual_tokens
:
int
# Number of tokens excluding padding.
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
qo_indptr_cpu
:
torch
.
Tensor
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
paged_kv_indptr_cpu
:
torch
.
Tensor
# The page indices of the paged kv cache (on device for plan)
paged_kv_indices
:
torch
.
Tensor
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size] (CPU for plan)
paged_kv_last_page_len_cpu
:
torch
.
Tensor
# The data type of the query
q_data_type
:
torch
.
dtype
seq_lens_cpu
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
# For flashinfer trtllm batch decode
...
...
@@ -164,10 +144,6 @@ class FlashInferMetadata:
# For cascade attention (CPU for planning).
use_cascade
:
bool
shared_qo_indptr_cpu
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_page_indptr_cpu
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_page_indices_cpu
:
Optional
[
torch
.
Tensor
]
=
None
shared_kv_last_page_len_cpu
:
Optional
[
torch
.
Tensor
]
=
None
prefill_wrapper
:
Optional
[
BatchPrefillWithPagedKVCacheWrapper
]
=
None
decode_wrapper
:
Optional
[
BatchDecodeWithPagedKVCacheWrapper
]
=
None
...
...
@@ -327,134 +303,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
2
,
self
.
_get_workspace_buffer
(),
get_kv_cache_layout
())
return
self
.
_cascade_wrapper
def
_plan
(
self
,
attn_metadata
:
FlashInferMetadata
):
if
attn_metadata
.
use_cascade
:
attn_metadata
.
cascade_wrapper
=
self
.
_get_cascade_wrapper
()
attn_metadata
.
cascade_wrapper
.
plan
(
[
attn_metadata
.
shared_qo_indptr_cpu
,
attn_metadata
.
qo_indptr_cpu
],
[
attn_metadata
.
shared_kv_page_indptr_cpu
,
attn_metadata
.
paged_kv_indptr_cpu
],
[
attn_metadata
.
shared_kv_page_indices_cpu
,
attn_metadata
.
paged_kv_indices
],
[
attn_metadata
.
shared_kv_last_page_len_cpu
,
attn_metadata
.
paged_kv_last_page_len_cpu
],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
else
:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
num_prefills
=
attn_metadata
.
num_prefills
num_decodes
=
attn_metadata
.
num_decodes
if
num_prefills
>
0
:
# Decodes are first so prefills start after the last decode
prefill_start
=
num_decodes
attn_metadata
.
prefill_wrapper
=
self
.
_get_prefill_wrapper
()
assert
attn_metadata
.
qo_indptr_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
attn_metadata
.
paged_kv_indptr_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
attn_metadata
.
paged_kv_last_page_len_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr_cpu
=
attn_metadata
.
qo_indptr_cpu
[
prefill_start
:]
-
attn_metadata
.
qo_indptr_cpu
[
prefill_start
]
paged_kv_indptr_cpu
=
attn_metadata
.
paged_kv_indptr_cpu
[
prefill_start
:]
if
not
attn_metadata
.
prefill_use_trtllm
:
attn_metadata
.
prefill_wrapper
.
plan
(
qo_indptr_cpu
,
paged_kv_indptr_cpu
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_len_cpu
[
prefill_start
:],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
else
:
attn_metadata
.
qo_indptr_gpu
=
qo_indptr_cpu
.
to
(
self
.
device
)
attn_metadata
.
paged_kv_indptr_gpu
=
paged_kv_indptr_cpu
.
to
(
self
.
device
)
if
num_decodes
>
0
:
pure_decode
=
num_prefills
==
0
# possible required padding for cudagraph replay
use_cudagraph
=
(
self
.
enable_cuda_graph
and
pure_decode
and
num_decodes
<=
self
.
_decode_cudagraph_max_bs
)
if
use_cudagraph
:
num_input_tokens
=
(
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
))
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
self
.
paged_kv_indptr_cpu
[
1
+
num_decodes
:
1
+
num_input_tokens
].
fill_
(
attn_metadata
.
paged_kv_indptr_cpu
[
-
1
])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self
.
paged_kv_last_page_len_cpu
[
num_decodes
:
num_input_tokens
].
fill_
(
1
)
else
:
num_input_tokens
=
num_decodes
attn_metadata
.
decode_wrapper
=
self
.
_get_decode_wrapper
(
num_input_tokens
,
use_cudagraph
)
if
not
attn_metadata
.
decode_use_trtllm
:
# Use the persistent buffer with padding length,
# instead of the same address but chunked version
# in atten_metadata when using cudagraph.
fast_plan_decode
(
attn_metadata
.
decode_wrapper
,
self
.
paged_kv_indptr_cpu
[:
num_input_tokens
+
1
],
attn_metadata
.
paged_kv_indices
,
self
.
paged_kv_last_page_len_cpu
[:
num_input_tokens
],
attn_metadata
.
seq_lens_cpu
[:
num_input_tokens
],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
...
...
@@ -548,13 +396,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata
=
FlashInferMetadata
(
num_actual_tokens
=
num_actual_tokens
,
qo_indptr_cpu
=
common_attn_metadata
.
query_start_loc_cpu
,
paged_kv_indptr_cpu
=
self
.
paged_kv_indptr_cpu
[:
1
+
num_reqs
],
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
],
q_data_type
=
self
.
q_data_type
,
seq_lens_cpu
=
seq_lens_cpu
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
max_q_len
=
max_q_len
,
max_seq_len
=
max_seq_len
,
...
...
@@ -567,14 +409,123 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
use_cascade
=
use_cascade
,
shared_qo_indptr_cpu
=
shared_qo_indptr_cpu
,
shared_kv_page_indptr_cpu
=
shared_kv_page_indptr_cpu
,
shared_kv_page_indices_cpu
=
shared_kv_page_indices_cpu
,
shared_kv_last_page_len_cpu
=
shared_kv_last_page_len_cpu
,
)
self
.
_plan
(
attn_metadata
)
qo_indptr_cpu
=
common_attn_metadata
.
query_start_loc_cpu
paged_kv_indptr_cpu
=
self
.
paged_kv_indptr_cpu
[:
1
+
num_reqs
]
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
]
if
attn_metadata
.
use_cascade
:
attn_metadata
.
cascade_wrapper
=
self
.
_get_cascade_wrapper
()
attn_metadata
.
cascade_wrapper
.
plan
(
[
shared_qo_indptr_cpu
,
qo_indptr_cpu
],
[
shared_kv_page_indptr_cpu
,
paged_kv_indptr_cpu
],
[
shared_kv_page_indices_cpu
,
paged_kv_indices
],
[
shared_kv_last_page_len_cpu
,
paged_kv_last_page_len_cpu
],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
else
:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
num_prefills
=
attn_metadata
.
num_prefills
num_decodes
=
attn_metadata
.
num_decodes
if
num_prefills
>
0
:
# Decodes are first so prefills start after the last decode
prefill_start
=
num_decodes
attn_metadata
.
prefill_wrapper
=
self
.
_get_prefill_wrapper
()
assert
qo_indptr_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
paged_kv_indptr_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
paged_kv_last_page_len_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr_cpu
=
qo_indptr_cpu
[
prefill_start
:]
-
qo_indptr_cpu
[
prefill_start
]
paged_kv_indptr_cpu
=
paged_kv_indptr_cpu
[
prefill_start
:]
if
not
attn_metadata
.
prefill_use_trtllm
:
attn_metadata
.
prefill_wrapper
.
plan
(
qo_indptr_cpu
,
paged_kv_indptr_cpu
,
paged_kv_indices
,
paged_kv_last_page_len_cpu
[
prefill_start
:],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
else
:
attn_metadata
.
qo_indptr_gpu
=
qo_indptr_cpu
.
to
(
self
.
device
)
attn_metadata
.
paged_kv_indptr_gpu
=
paged_kv_indptr_cpu
.
to
(
self
.
device
)
if
num_decodes
>
0
:
pure_decode
=
num_prefills
==
0
# possible required padding for cudagraph replay
use_cudagraph
=
(
self
.
enable_cuda_graph
and
pure_decode
and
num_decodes
<=
self
.
_decode_cudagraph_max_bs
)
if
use_cudagraph
:
num_input_tokens
=
(
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
))
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
self
.
paged_kv_indptr_cpu
[
1
+
num_decodes
:
1
+
num_input_tokens
].
fill_
(
paged_kv_indptr_cpu
[
-
1
])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self
.
paged_kv_last_page_len_cpu
[
num_decodes
:
num_input_tokens
].
fill_
(
1
)
else
:
num_input_tokens
=
num_decodes
attn_metadata
.
decode_wrapper
=
self
.
_get_decode_wrapper
(
num_input_tokens
,
use_cudagraph
)
if
not
attn_metadata
.
decode_use_trtllm
:
# Use the persistent buffer with padding length,
# instead of the same address but chunked version
# in atten_metadata when using cudagraph.
fast_plan_decode
(
attn_metadata
.
decode_wrapper
,
self
.
paged_kv_indptr_cpu
[:
num_input_tokens
+
1
],
paged_kv_indices
,
self
.
paged_kv_last_page_len_cpu
[:
num_input_tokens
],
seq_lens_cpu
[:
num_input_tokens
],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
return
attn_metadata
def
build_for_cudagraph_capture
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment