Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ff93cc8c
Unverified
Commit
ff93cc8c
authored
Oct 23, 2025
by
Andrew Sansom
Committed by
GitHub
Oct 22, 2025
Browse files
[CORE] Support Prefix Caching with Prompt Embeds (#27219)
Signed-off-by:
Andrew Sansom
<
andrew@protopia.ai
>
parent
243ed7d3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
185 additions
and
18 deletions
+185
-18
docs/features/README.md
docs/features/README.md
+2
-2
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+136
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-10
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+30
-2
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+4
-3
vllm/v1/utils.py
vllm/v1/utils.py
+13
-0
No files found.
docs/features/README.md
View file @
ff93cc8c
...
...
@@ -52,7 +52,7 @@ th:not(:first-child) {
|
[
mm
](
multimodal_inputs.md
)
| ✅ | ✅ |
[
🟠
](
https://github.com/vllm-project/vllm/pull/4194
)
<sup>
^
</sup>
| ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
| best-of | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/6137
)
| ✅ | ❌ | ✅ | ✅ | ✅ | ❔ |
[
❌
](
https://github.com/vllm-project/vllm/issues/7968
)
| ✅ | ✅ | | |
| beam-search | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/6137
)
| ✅ | ❌ | ✅ | ✅ | ✅ | ❔ |
[
❌
](
https://github.com/vllm-project/vllm/issues/7968
)
| ❔ | ✅ | ✅ | |
|
[
prompt-embeds
](
prompt_embeds.md
)
| ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/25096
)
| ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
|
[
prompt-embeds
](
prompt_embeds.md
)
| ✅ |
✅
| ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\*
Chunked prefill and prefix caching are only applicable to last-token pooling.
<sup>
^
</sup>
LoRA is only applicable to the language backbone of multimodal models.
...
...
@@ -75,4 +75,4 @@ th:not(:first-child) {
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/8477
)
| ✅ | ❌ | ✅ |
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
|
[
prompt-embeds
](
prompt_embeds.md
)
| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
?
|
[
❌
](
https://github.com/vllm-project/vllm/issues/25097
)
| ✅ |
|
[
prompt-embeds
](
prompt_embeds.md
)
| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
❔
|
[
❌
](
https://github.com/vllm-project/vllm/issues/25097
)
| ✅ |
tests/v1/core/test_kv_cache_utils.py
View file @
ff93cc8c
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
from
collections.abc
import
Callable
from
typing
import
Any
import
pytest
import
torch
...
...
@@ -32,6 +33,7 @@ from vllm.v1.core.kv_cache_utils import (
init_none_hash
,
is_kv_cache_spec_uniform
,
make_block_hash_with_group_id
,
tensor_data
,
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
...
...
@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):
def
make_request
(
request_id
:
str
,
prompt_token_ids
:
list
[
int
],
prompt_token_ids
:
list
[
int
]
|
None
,
block_size
:
int
=
3
,
hash_fn
:
Callable
=
hash
,
mm_positions
:
list
[
PlaceholderRange
]
|
None
=
None
,
mm_hashes
:
list
[
str
]
|
None
=
None
,
cache_salt
:
str
|
None
=
None
,
prompt_embeds
:
torch
.
Tensor
|
None
=
None
,
):
mm_features
=
[]
if
mm_positions
is
not
None
:
...
...
@@ -90,6 +93,7 @@ def make_request(
lora_request
=
None
,
cache_salt
=
cache_salt
,
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
),
prompt_embeds
=
prompt_embeds
,
)
...
...
@@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert
next_mm_idx
==
1
def
test_generate_block_hash_extra_keys_prompt_embeds
():
prompt_embeds
=
torch
.
randn
(
10
,
3
)
request
=
make_request
(
request_id
=
"0"
,
prompt_token_ids
=
None
,
mm_positions
=
None
,
mm_hashes
=
None
,
prompt_embeds
=
prompt_embeds
,
)
# Test with prompt embeds for the first block
extra_keys
,
_
=
generate_block_hash_extra_keys
(
request
,
0
,
5
,
0
)
expected_embeds
=
prompt_embeds
[
0
:
5
]
expected_bytes
=
kv_cache_utils
.
tensor_data
(
expected_embeds
).
tobytes
()
assert
extra_keys
==
(
expected_bytes
,)
# Test with prompt embeds for the second block
extra_keys
,
_
=
generate_block_hash_extra_keys
(
request
,
5
,
10
,
0
)
expected_embeds
=
prompt_embeds
[
5
:
10
]
expected_bytes
=
kv_cache_utils
.
tensor_data
(
expected_embeds
).
tobytes
()
assert
extra_keys
==
(
expected_bytes
,)
def
test_generate_block_hash_extra_keys_different_prompt_embeds
():
prompt_embeds1
=
torch
.
randn
(
10
,
3
)
prompt_embeds2
=
torch
.
randn
(
10
,
3
)
request1
=
make_request
(
request_id
=
"0"
,
prompt_token_ids
=
None
,
mm_positions
=
None
,
mm_hashes
=
None
,
prompt_embeds
=
prompt_embeds1
,
)
request2
=
make_request
(
request_id
=
"1"
,
prompt_token_ids
=
None
,
mm_positions
=
None
,
mm_hashes
=
None
,
prompt_embeds
=
prompt_embeds2
,
)
extra_keys1
,
_
=
generate_block_hash_extra_keys
(
request1
,
0
,
5
,
0
)
extra_keys2
,
_
=
generate_block_hash_extra_keys
(
request2
,
0
,
5
,
0
)
assert
extra_keys1
!=
extra_keys2
def
test_generate_block_hash_extra_keys_lora
():
request
=
make_request
(
request_id
=
"0"
,
...
...
@@ -1556,3 +1606,88 @@ def test_merge_mla_spec():
]
with
pytest
.
raises
(
AssertionError
):
kv_cache_specs
[
0
].
merge
(
kv_cache_specs
)
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_request_block_hasher_with_prompt_embeds
(
hash_fn
:
Callable
[[
Any
],
bytes
]):
block_size
=
3
num_tokens
=
2
*
block_size
prompt_token_ids
=
[
_
for
_
in
range
(
num_tokens
)]
hidden_size
=
5
prompt_embeds
=
torch
.
randn
((
num_tokens
,
hidden_size
))
request
=
make_request
(
request_id
=
"0"
,
prompt_token_ids
=
prompt_token_ids
,
block_size
=
block_size
,
hash_fn
=
hash_fn
,
prompt_embeds
=
prompt_embeds
,
)
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
block1_embeds_bytes
=
tensor_data
(
prompt_embeds
[:
block_size
]).
tobytes
()
expected_hash1
=
hash_fn
(
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
prompt_token_ids
[:
block_size
]),
(
block1_embeds_bytes
,),
)
)
assert
block_hashes
[
0
]
==
expected_hash1
block2_embeds_bytes
=
tensor_data
(
prompt_embeds
[
block_size
:
num_tokens
]).
tobytes
()
expected_hash2
=
hash_fn
(
(
block_hashes
[
0
],
tuple
(
prompt_token_ids
[
block_size
:
num_tokens
]),
(
block2_embeds_bytes
,),
)
)
assert
block_hashes
[
1
]
==
expected_hash2
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_request_with_prompt_embeds_and_mm_inputs
(
hash_fn
:
Callable
[[
Any
],
bytes
]):
block_size
=
3
num_tokens
=
2
*
block_size
prompt_token_ids
=
[
_
for
_
in
range
(
num_tokens
)]
hidden_size
=
5
prompt_embeds
=
torch
.
randn
((
num_tokens
,
hidden_size
))
request
=
make_request
(
request_id
=
"0"
,
prompt_token_ids
=
prompt_token_ids
,
block_size
=
block_size
,
hash_fn
=
hash_fn
,
mm_positions
=
[
PlaceholderRange
(
offset
=
0
,
length
=
3
),
PlaceholderRange
(
offset
=
3
,
length
=
3
),
],
mm_hashes
=
[
"hash1"
,
"hash2"
],
prompt_embeds
=
prompt_embeds
,
)
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
block1_embeds_bytes
=
tensor_data
(
prompt_embeds
[:
block_size
]).
tobytes
()
expected_hash1
=
hash_fn
(
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
prompt_token_ids
[:
block_size
]),
(
"hash1"
,
block1_embeds_bytes
),
)
)
assert
block_hashes
[
0
]
==
expected_hash1
block2_embeds_bytes
=
tensor_data
(
prompt_embeds
[
block_size
:
num_tokens
]).
tobytes
()
expected_hash2
=
hash_fn
(
(
block_hashes
[
0
],
tuple
(
prompt_token_ids
[
block_size
:
num_tokens
]),
(
"hash2"
,
block2_embeds_bytes
),
)
)
assert
block_hashes
[
1
]
==
expected_hash2
vllm/engine/arg_utils.py
View file @
ff93cc8c
...
...
@@ -1743,16 +1743,6 @@ class EngineArgs:
if
model_config
.
runner_type
!=
"pooling"
:
self
.
enable_chunked_prefill
=
True
# TODO: When prefix caching supports prompt embeds inputs, this
# check can be removed.
if
self
.
enable_prompt_embeds
and
self
.
enable_prefix_caching
is
not
False
:
logger
.
warning
(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V1. Prefix caching has "
"been disabled."
)
self
.
enable_prefix_caching
=
False
if
self
.
enable_prefix_caching
is
None
:
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
...
...
vllm/v1/core/kv_cache_utils.py
View file @
ff93cc8c
...
...
@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import (
UniformTypeKVCacheSpecs
,
)
from
vllm.v1.request
import
Request
from
vllm.v1.utils
import
tensor_data
# BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from `bytes` helps
...
...
@@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
return
[
request
.
lora_request
.
lora_name
]
def
_gen_prompt_embeds_extra_hash_keys
(
request
:
Request
,
start_token_idx
:
int
,
end_token_idx
:
int
)
->
list
[
bytes
]:
"""Generate extra keys related to prompt embeds for block hash computation.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
Returns:
Return prompt embeddings data of the request if it has prompt embeds.
Return empty list otherwise.
"""
if
request
.
prompt_embeds
is
None
:
return
[]
block_prompt_embeds
=
request
.
prompt_embeds
[
start_token_idx
:
end_token_idx
]
embeds_bytes
=
tensor_data
(
block_prompt_embeds
).
tobytes
()
return
[
embeds_bytes
]
def
generate_block_hash_extra_keys
(
request
:
Request
,
start_token_idx
:
int
,
end_token_idx
:
int
,
start_mm_idx
:
int
)
->
tuple
[
tuple
[
Any
,
...]
|
None
,
int
]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA name).
the multi-modal inputs, request specific metadata (e.g., LoRA names), and
data from prompt embeddings.
Args:
request: The request object.
...
...
@@ -484,8 +507,13 @@ def generate_block_hash_extra_keys(
cache_salt_keys
:
list
[
str
]
=
(
[
request
.
cache_salt
]
if
(
start_token_idx
==
0
and
request
.
cache_salt
)
else
[]
)
prompt_embeds_keys
=
_gen_prompt_embeds_extra_hash_keys
(
request
,
start_token_idx
,
end_token_idx
)
extra_keys
:
list
[
Any
]
=
lora_extra_keys
+
mm_extra_keys
+
cache_salt_keys
extra_keys
:
list
[
Any
]
=
(
lora_extra_keys
+
mm_extra_keys
+
cache_salt_keys
+
prompt_embeds_keys
)
if
not
extra_keys
:
return
None
,
new_start_mm_idx
...
...
vllm/v1/serial_utils.py
View file @
ff93cc8c
...
...
@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
NestedTensors
,
)
from
vllm.v1.engine
import
UtilityResult
from
vllm.v1.utils
import
tensor_data
logger
=
init_logger
(
__name__
)
...
...
@@ -218,14 +219,14 @@ class MsgpackEncoder:
)
->
tuple
[
str
,
tuple
[
int
,
...],
int
|
memoryview
]:
assert
self
.
aux_buffers
is
not
None
# view the tensor as a contiguous 1D array of bytes
arr
=
obj
.
flatten
().
contiguous
().
view
(
torch
.
uint8
).
numpy
(
)
arr
_data
=
tensor_data
(
obj
)
if
obj
.
nbytes
<
self
.
size_threshold
:
# Smaller tensors are encoded inline, just like ndarrays.
data
=
msgpack
.
Ext
(
CUSTOM_TYPE_RAW_VIEW
,
arr
.
data
)
data
=
msgpack
.
Ext
(
CUSTOM_TYPE_RAW_VIEW
,
arr
_
data
)
else
:
# Otherwise encode index of backing buffer to avoid copy.
data
=
len
(
self
.
aux_buffers
)
self
.
aux_buffers
.
append
(
arr
.
data
)
self
.
aux_buffers
.
append
(
arr
_
data
)
dtype
=
str
(
obj
.
dtype
).
removeprefix
(
"torch."
)
return
dtype
,
obj
.
shape
,
data
...
...
vllm/v1/utils.py
View file @
ff93cc8c
...
...
@@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
_PROFILER_FUNC
=
func
return
func
(
name
)
def
tensor_data
(
tensor
:
torch
.
Tensor
)
->
memoryview
:
"""Get the raw data of a tensor as a uint8 memoryview, useful for
serializing and hashing.
Args:
tensor: The input tensor.
Returns:
A memoryview of the tensor data as uint8.
"""
return
tensor
.
flatten
().
contiguous
().
view
(
torch
.
uint8
).
numpy
().
data
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment