Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ebed80a7
Unverified
Commit
ebed80a7
authored
Mar 06, 2026
by
Dor Huri
Committed by
GitHub
Mar 06, 2026
Browse files
[Performance] Extract KV-cache update from TreeAttention backend (#35384)
Signed-off-by:
dorhuri123
<
dor.huri1@live.biu.ac.il
>
parent
a73af584
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
19 deletions
+28
-19
vllm/v1/attention/backends/tree_attn.py
vllm/v1/attention/backends/tree_attn.py
+28
-19
No files found.
vllm/v1/attention/backends/tree_attn.py
View file @
ebed80a7
...
...
@@ -31,6 +31,7 @@ logger = init_logger(__name__)
class
TreeAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
forward_includes_kv_cache_update
:
bool
=
False
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
...
...
@@ -326,6 +327,33 @@ class TreeAttentionImpl(AttentionImpl):
"TreeAttentionImpl."
)
def
do_kv_cache_update
(
self
,
layer
:
torch
.
nn
.
Module
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
# Reshape the input keys and values and store them in the cache.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
def
forward
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -361,26 +389,7 @@ class TreeAttentionImpl(AttentionImpl):
# Profiling run.
return
output
.
fill_
(
0
)
# Cache the input KVs.
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment