Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6578e873
Unverified
Commit
6578e873
authored
Aug 27, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 27, 2025
Browse files
Optimize input preparation for FlashInfer [2/N] (#23174)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
5bd9f841
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
54 additions
and
26 deletions
+54
-26
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+54
-26
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
6578e873
...
...
@@ -6,6 +6,7 @@ from __future__ import annotations
from
dataclasses
import
dataclass
from
typing
import
ClassVar
,
Optional
,
Union
import
numpy
as
np
import
torch
from
flashinfer
import
(
BatchDecodeWithPagedKVCacheWrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
...
...
@@ -22,6 +23,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8StaticTensorSym
,
kNvfp4Quant
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
is_pin_memory_available
from
vllm.utils.flashinfer
import
(
supports_trtllm_attention
,
use_trtllm_attention
)
...
...
@@ -230,6 +232,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_indptr_np
=
self
.
paged_kv_indptr_cpu
.
numpy
()
self
.
paged_kv_indices_cpu
=
torch
.
zeros
(
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
...
...
@@ -238,10 +241,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
block_table_arange
=
torch
.
arange
(
max_num_pages_per_req
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
paged_kv_last_page_len_np
=
(
self
.
paged_kv_last_page_len_cpu
.
numpy
())
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
...
...
@@ -317,9 +318,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
max_seq_len
=
common_attn_metadata
.
max_seq_len
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_np
=
seq_lens_cpu
.
numpy
()
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block
_table_bounds_cpu
=
(
seq_lens_
cpu
+
page_size
-
1
)
//
page_size
num_
block
s_np
=
(
seq_lens_
np
+
(
page_size
-
1
)
)
//
page_size
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
...
...
@@ -342,37 +344,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Remove the blocks of the shared prefix from all requests.
block_table_tensor
=
block_table_tensor
[:,
num_common_kv_blocks
:]
block
_table_bounds_cpu
-=
num_common_kv_blocks
num_
block
s_np
-=
num_common_kv_blocks
else
:
shared_qo_indptr_cpu
=
None
shared_kv_page_indptr_cpu
=
None
shared_kv_page_indices_cpu
=
None
shared_kv_last_page_len_cpu
=
None
max_num_blocks
=
block_table_bounds_cpu
.
max
().
item
()
block_table_bounds
=
block_table_bounds_cpu
.
to
(
self
.
device
,
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
np
.
cumsum
(
num_blocks_np
,
dtype
=
np
.
int32
,
out
=
self
.
paged_kv_indptr_np
[
1
:
num_reqs
+
1
],
)
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
num_reqs
+
1
]
paged_kv_indptr
.
copy_
(
self
.
paged_kv_indptr_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
mask
=
(
self
.
block_table_arange
[:
max_num_blocks
].
unsqueeze
(
0
)
<
block_table_bounds
.
unsqueeze
(
1
))
# write self.paged_kv_indices inplace
num_actual_pages
=
torch
.
sum
(
mask
)
num_actual_pages
=
num_blocks_np
.
sum
().
item
(
)
paged_kv_indices
=
self
.
paged_kv_indices
[:
num_actual_pages
]
torch
.
masked_select
(
block_table_tensor
[:,
:
max_num_blocks
],
mask
,
out
=
paged_kv_indices
)
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
torch
.
cumsum
(
block_table_bounds_cpu
,
dim
=
0
,
dtype
=
torch
.
int32
,
out
=
self
.
paged_kv_indptr_cpu
[
1
:
1
+
num_reqs
])
_copy_page_indices_kernel
[(
num_reqs
,
)](
paged_kv_indices
,
block_table_tensor
,
block_table_tensor
.
stride
(
0
),
paged_kv_indptr
,
BLOCK_SIZE
=
1024
,
)
paged_kv_last_page_len_cpu
=
seq_lens_cpu
%
page_size
# write self.paged_kv_last_page_len_cpu inplace
torch
.
where
(
paged_kv_last_page_len_cpu
==
0
,
torch
.
tensor
(
page_size
),
paged_kv_last_page_len_cpu
,
out
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
])
paged_kv_last_page_len_np
=
seq_lens_np
%
page_size
self
.
paged_kv_last_page_len_np
[:
num_reqs
]
=
np
.
where
(
paged_kv_last_page_len_np
==
0
,
page_size
,
paged_kv_last_page_len_np
,
)
# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks
=
self
.
global_hyperparameters
.
has_sinks
...
...
@@ -1002,3 +1008,25 @@ def fast_plan_decode(
self
.
_sm_scale
=
sm_scale
self
.
_rope_scale
=
rope_scale
self
.
_rope_theta
=
rope_theta
@
triton
.
jit
def
_copy_page_indices_kernel
(
page_indices
,
block_table
,
block_table_stride
,
cu_num_blocks
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
row_ptr
=
block_table
+
req_idx
*
block_table_stride
start_idx
=
tl
.
load
(
cu_num_blocks
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_blocks
+
req_idx
+
1
)
num_blocks
=
end_idx
-
start_idx
offset
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
for
i
in
tl
.
range
(
0
,
num_blocks
,
BLOCK_SIZE
):
block_ids
=
tl
.
load
(
row_ptr
+
i
+
offset
,
mask
=
i
+
offset
<
num_blocks
)
tl
.
store
(
page_indices
+
start_idx
+
i
+
offset
,
block_ids
,
mask
=
i
+
offset
<
num_blocks
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment