Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98217b09
Unverified
Commit
98217b09
authored
Feb 26, 2026
by
ElizaWszola
Committed by
GitHub
Feb 26, 2026
Browse files
[Performance] Extract KV cache update op from flashinfer forward (#35422)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
parent
967572dd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
25 deletions
+37
-25
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+37
-25
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
98217b09
...
...
@@ -381,6 +381,8 @@ class FlashInferBackend(AttentionBackend):
return
"HND"
return
None
forward_includes_kv_cache_update
:
bool
=
False
@
dataclass
class
FIPrefill
:
...
...
@@ -1330,28 +1332,11 @@ class FlashInferImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[:,
0
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_sharing_target_layer_name
is
None
and
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
torch_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
kv_cache_dtype
)
...
...
@@ -1599,6 +1584,33 @@ class FlashInferImpl(AttentionImpl):
)
return
output_padded
def do_kv_cache_update(
    self,
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Store ``key`` and ``value`` into ``kv_cache``.

    Does nothing when ``self.kv_sharing_target_layer_name`` is set, i.e.
    when this layer shares its KV cache with an earlier attention layer.

    Args:
        layer: Module whose ``_k_scale`` and ``_v_scale`` attributes are
            forwarded to the cache-write op.
        key: Keys to store.
        value: Values to store.
        kv_cache: Cache tensor; ``kv_cache[:, 0]`` and ``kv_cache[:, 1]``
            are handed to the op as the two destination buffers
            (presumably the key cache and the value cache, in that
            order - confirm against the op's signature).
        slot_mapping: Slot mapping forwarded to the cache-write op.
    """
    if self.kv_sharing_target_layer_name is not None:
        # The cache is shared with an earlier attention layer; skip the
        # write for this layer.
        return

    # Look the op up first so failures surface in the same order as a
    # direct call expression would produce them.
    write_to_cache = torch.ops._C_cache_ops.reshape_and_cache_flash

    first_cache_slice = kv_cache[:, 0]
    second_cache_slice = kv_cache[:, 1]

    # NOTE(woosuk): key and value are padded while slot_mapping is not.
    # Slicing them to the actual token count is unnecessary because the
    # reshape_and_cache_flash op uses the slot_mapping's shape to
    # determine the number of actual tokens.
    write_to_cache(
        key,
        value,
        first_cache_slice,
        second_cache_slice,
        slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
def
fast_plan_decode
(
self
,
# decode wrapper
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment