Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ff93cc8c
Unverified
Commit
ff93cc8c
authored
Oct 23, 2025
by
Andrew Sansom
Committed by
GitHub
Oct 22, 2025
Browse files
[CORE] Support Prefix Caching with Prompt Embeds (#27219)
Signed-off-by:
Andrew Sansom
<
andrew@protopia.ai
>
parent
243ed7d3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
185 additions
and
18 deletions
+185
-18
docs/features/README.md
docs/features/README.md
+2
-2
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+136
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-10
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+30
-2
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+4
-3
vllm/v1/utils.py
vllm/v1/utils.py
+13
-0
No files found.
docs/features/README.md
View file @
ff93cc8c
...
@@ -52,7 +52,7 @@ th:not(:first-child) {
...
@@ -52,7 +52,7 @@ th:not(:first-child) {
|
[
mm
](
multimodal_inputs.md
)
| ✅ | ✅ |
[
🟠
](
https://github.com/vllm-project/vllm/pull/4194
)
<sup>
^
</sup>
| ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
|
[
mm
](
multimodal_inputs.md
)
| ✅ | ✅ |
[
🟠
](
https://github.com/vllm-project/vllm/pull/4194
)
<sup>
^
</sup>
| ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
| best-of | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/6137
)
| ✅ | ❌ | ✅ | ✅ | ✅ | ❔ |
[
❌
](
https://github.com/vllm-project/vllm/issues/7968
)
| ✅ | ✅ | | |
| best-of | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/6137
)
| ✅ | ❌ | ✅ | ✅ | ✅ | ❔ |
[
❌
](
https://github.com/vllm-project/vllm/issues/7968
)
| ✅ | ✅ | | |
| beam-search | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/6137
)
| ✅ | ❌ | ✅ | ✅ | ✅ | ❔ |
[
❌
](
https://github.com/vllm-project/vllm/issues/7968
)
| ❔ | ✅ | ✅ | |
| beam-search | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/6137
)
| ✅ | ❌ | ✅ | ✅ | ✅ | ❔ |
[
❌
](
https://github.com/vllm-project/vllm/issues/7968
)
| ❔ | ✅ | ✅ | |
|
[
prompt-embeds
](
prompt_embeds.md
)
| ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/25096
)
| ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
|
[
prompt-embeds
](
prompt_embeds.md
)
| ✅ |
✅
| ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\*
Chunked prefill and prefix caching are only applicable to last-token pooling.
\*
Chunked prefill and prefix caching are only applicable to last-token pooling.
<sup>
^
</sup>
LoRA is only applicable to the language backbone of multimodal models.
<sup>
^
</sup>
LoRA is only applicable to the language backbone of multimodal models.
...
@@ -75,4 +75,4 @@ th:not(:first-child) {
...
@@ -75,4 +75,4 @@ th:not(:first-child) {
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/8477
)
| ✅ | ❌ | ✅ |
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ |
[
❌
](
https://github.com/vllm-project/vllm/issues/8477
)
| ✅ | ❌ | ✅ |
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
|
[
prompt-embeds
](
prompt_embeds.md
)
| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
?
|
[
❌
](
https://github.com/vllm-project/vllm/issues/25097
)
| ✅ |
|
[
prompt-embeds
](
prompt_embeds.md
)
| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
❔
|
[
❌
](
https://github.com/vllm-project/vllm/issues/25097
)
| ✅ |
tests/v1/core/test_kv_cache_utils.py
View file @
ff93cc8c
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
import
importlib
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
Any
import
pytest
import
pytest
import
torch
import
torch
...
@@ -32,6 +33,7 @@ from vllm.v1.core.kv_cache_utils import (
...
@@ -32,6 +33,7 @@ from vllm.v1.core.kv_cache_utils import (
init_none_hash
,
init_none_hash
,
is_kv_cache_spec_uniform
,
is_kv_cache_spec_uniform
,
make_block_hash_with_group_id
,
make_block_hash_with_group_id
,
tensor_data
,
)
)
from
vllm.v1.kv_cache_interface
import
(
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
FullAttentionSpec
,
...
@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):
...
@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):
def
make_request
(
def
make_request
(
request_id
:
str
,
request_id
:
str
,
prompt_token_ids
:
list
[
int
],
prompt_token_ids
:
list
[
int
]
|
None
,
block_size
:
int
=
3
,
block_size
:
int
=
3
,
hash_fn
:
Callable
=
hash
,
hash_fn
:
Callable
=
hash
,
mm_positions
:
list
[
PlaceholderRange
]
|
None
=
None
,
mm_positions
:
list
[
PlaceholderRange
]
|
None
=
None
,
mm_hashes
:
list
[
str
]
|
None
=
None
,
mm_hashes
:
list
[
str
]
|
None
=
None
,
cache_salt
:
str
|
None
=
None
,
cache_salt
:
str
|
None
=
None
,
prompt_embeds
:
torch
.
Tensor
|
None
=
None
,
):
):
mm_features
=
[]
mm_features
=
[]
if
mm_positions
is
not
None
:
if
mm_positions
is
not
None
:
...
@@ -90,6 +93,7 @@ def make_request(
...
@@ -90,6 +93,7 @@ def make_request(
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
cache_salt
,
cache_salt
=
cache_salt
,
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
),
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
),
prompt_embeds
=
prompt_embeds
,
)
)
...
@@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt():
...
@@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert
next_mm_idx
==
1
assert
next_mm_idx
==
1
def test_generate_block_hash_extra_keys_prompt_embeds():
    """Check the extra keys produced for a request that carries prompt embeds.

    For each block, the only extra key is expected to be the bytes of the
    prompt-embeds slice covering that block's token range.
    """
    embeds = torch.randn(10, 3)
    request = make_request(
        request_id="0",
        prompt_token_ids=None,
        mm_positions=None,
        mm_hashes=None,
        prompt_embeds=embeds,
    )

    # Exercise the first block (tokens 0-5), then the second (tokens 5-10).
    for start, end in ((0, 5), (5, 10)):
        extra_keys, _ = generate_block_hash_extra_keys(request, start, end, 0)
        block_embeds = embeds[start:end]
        block_bytes = kv_cache_utils.tensor_data(block_embeds).tobytes()
        assert extra_keys == (block_bytes,)
def test_generate_block_hash_extra_keys_different_prompt_embeds():
    """Two requests with different prompt embeds must get different extra keys."""
    # Build both random tensors first, then both requests, then the keys.
    embeds_pair = [torch.randn(10, 3) for _ in range(2)]
    requests = [
        make_request(
            request_id=str(index),
            prompt_token_ids=None,
            mm_positions=None,
            mm_hashes=None,
            prompt_embeds=embeds,
        )
        for index, embeds in enumerate(embeds_pair)
    ]

    first_keys, second_keys = [
        generate_block_hash_extra_keys(request, 0, 5, 0)[0] for request in requests
    ]
    assert first_keys != second_keys
def
test_generate_block_hash_extra_keys_lora
():
def
test_generate_block_hash_extra_keys_lora
():
request
=
make_request
(
request
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
...
@@ -1556,3 +1606,88 @@ def test_merge_mla_spec():
...
@@ -1556,3 +1606,88 @@ def test_merge_mla_spec():
]
]
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
kv_cache_specs
[
0
].
merge
(
kv_cache_specs
)
kv_cache_specs
[
0
].
merge
(
kv_cache_specs
)
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
    """Verify block hashes of a request that has both token ids and prompt embeds.

    Each block hash is expected to equal ``hash_fn`` applied to the tuple
    (parent hash, block token ids, (block embeds bytes,)), where the first
    block's parent is ``kv_cache_utils.NONE_HASH``.
    """
    block_size = 3
    num_tokens = 2 * block_size
    hidden_size = 5
    token_ids = list(range(num_tokens))
    embeds = torch.randn((num_tokens, hidden_size))

    request = make_request(
        request_id="0",
        prompt_token_ids=token_ids,
        block_size=block_size,
        hash_fn=hash_fn,
        prompt_embeds=embeds,
    )

    hashes = request.block_hashes
    assert len(hashes) == 2

    # Walk the two blocks in order, chaining each hash into the next one.
    parent_hash = kv_cache_utils.NONE_HASH
    for block_idx, start in enumerate(range(0, num_tokens, block_size)):
        end = start + block_size
        embeds_bytes = tensor_data(embeds[start:end]).tobytes()
        expected_hash = hash_fn(
            (
                parent_hash,
                tuple(token_ids[start:end]),
                (embeds_bytes,),
            )
        )
        assert hashes[block_idx] == expected_hash
        parent_hash = hashes[block_idx]
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]):
    """Verify block hashes when a request has both mm inputs and prompt embeds.

    Each block's extra keys are expected to be the multi-modal hash for that
    block followed by the bytes of the block's prompt-embeds slice.
    """
    block_size = 3
    num_tokens = 2 * block_size
    hidden_size = 5
    token_ids = list(range(num_tokens))
    embeds = torch.randn((num_tokens, hidden_size))

    request = make_request(
        request_id="0",
        prompt_token_ids=token_ids,
        block_size=block_size,
        hash_fn=hash_fn,
        mm_positions=[
            PlaceholderRange(offset=0, length=3),
            PlaceholderRange(offset=3, length=3),
        ],
        mm_hashes=["hash1", "hash2"],
        prompt_embeds=embeds,
    )

    hashes = request.block_hashes
    assert len(hashes) == 2

    # One multi-modal item per block; chain each block hash into the next.
    parent_hash = kv_cache_utils.NONE_HASH
    for block_idx, mm_hash in enumerate(("hash1", "hash2")):
        start = block_idx * block_size
        end = start + block_size
        embeds_bytes = tensor_data(embeds[start:end]).tobytes()
        expected_hash = hash_fn(
            (
                parent_hash,
                tuple(token_ids[start:end]),
                (mm_hash, embeds_bytes),
            )
        )
        assert hashes[block_idx] == expected_hash
        parent_hash = hashes[block_idx]
vllm/engine/arg_utils.py
View file @
ff93cc8c
...
@@ -1743,16 +1743,6 @@ class EngineArgs:
...
@@ -1743,16 +1743,6 @@ class EngineArgs:
if
model_config
.
runner_type
!=
"pooling"
:
if
model_config
.
runner_type
!=
"pooling"
:
self
.
enable_chunked_prefill
=
True
self
.
enable_chunked_prefill
=
True
# TODO: When prefix caching supports prompt embeds inputs, this
# check can be removed.
if
self
.
enable_prompt_embeds
and
self
.
enable_prefix_caching
is
not
False
:
logger
.
warning
(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V1. Prefix caching has "
"been disabled."
)
self
.
enable_prefix_caching
=
False
if
self
.
enable_prefix_caching
is
None
:
if
self
.
enable_prefix_caching
is
None
:
# Disable prefix caching default for hybrid models
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
# since the feature is still experimental.
...
...
vllm/v1/core/kv_cache_utils.py
View file @
ff93cc8c
...
@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import (
...
@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import (
UniformTypeKVCacheSpecs
,
UniformTypeKVCacheSpecs
,
)
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
from
vllm.v1.utils
import
tensor_data
# BlockHash represents the hash of a single KV-cache block used for
# BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from `bytes` helps
# prefix caching. Treating it as a distinct type from `bytes` helps
...
@@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
...
@@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
return
[
request
.
lora_request
.
lora_name
]
return
[
request
.
lora_request
.
lora_name
]
def _gen_prompt_embeds_extra_hash_keys(
    request: Request, start_token_idx: int, end_token_idx: int
) -> list[bytes]:
    """Build the prompt-embeds extra keys used when hashing one block.

    Args:
        request: The request object.
        start_token_idx: The start token index of the block.
        end_token_idx: The end token index of the block.

    Returns:
        A single-element list holding the bytes of the request's prompt
        embeds sliced to ``[start_token_idx:end_token_idx]``, or an empty
        list when the request has no prompt embeds.
    """
    embeds = request.prompt_embeds
    if embeds is not None:
        block_slice = embeds[start_token_idx:end_token_idx]
        return [tensor_data(block_slice).tobytes()]
    return []
def
generate_block_hash_extra_keys
(
def
generate_block_hash_extra_keys
(
request
:
Request
,
start_token_idx
:
int
,
end_token_idx
:
int
,
start_mm_idx
:
int
request
:
Request
,
start_token_idx
:
int
,
end_token_idx
:
int
,
start_mm_idx
:
int
)
->
tuple
[
tuple
[
Any
,
...]
|
None
,
int
]:
)
->
tuple
[
tuple
[
Any
,
...]
|
None
,
int
]:
"""Generate extra keys for the block hash. The extra keys can come from
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA name).
the multi-modal inputs, request specific metadata (e.g., LoRA names), and
data from prompt embeddings.
Args:
Args:
request: The request object.
request: The request object.
...
@@ -484,8 +507,13 @@ def generate_block_hash_extra_keys(
...
@@ -484,8 +507,13 @@ def generate_block_hash_extra_keys(
cache_salt_keys
:
list
[
str
]
=
(
cache_salt_keys
:
list
[
str
]
=
(
[
request
.
cache_salt
]
if
(
start_token_idx
==
0
and
request
.
cache_salt
)
else
[]
[
request
.
cache_salt
]
if
(
start_token_idx
==
0
and
request
.
cache_salt
)
else
[]
)
)
prompt_embeds_keys
=
_gen_prompt_embeds_extra_hash_keys
(
request
,
start_token_idx
,
end_token_idx
)
extra_keys
:
list
[
Any
]
=
lora_extra_keys
+
mm_extra_keys
+
cache_salt_keys
extra_keys
:
list
[
Any
]
=
(
lora_extra_keys
+
mm_extra_keys
+
cache_salt_keys
+
prompt_embeds_keys
)
if
not
extra_keys
:
if
not
extra_keys
:
return
None
,
new_start_mm_idx
return
None
,
new_start_mm_idx
...
...
vllm/v1/serial_utils.py
View file @
ff93cc8c
...
@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
...
@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
NestedTensors
,
NestedTensors
,
)
)
from
vllm.v1.engine
import
UtilityResult
from
vllm.v1.engine
import
UtilityResult
from
vllm.v1.utils
import
tensor_data
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -218,14 +219,14 @@ class MsgpackEncoder:
...
@@ -218,14 +219,14 @@ class MsgpackEncoder:
)
->
tuple
[
str
,
tuple
[
int
,
...],
int
|
memoryview
]:
)
->
tuple
[
str
,
tuple
[
int
,
...],
int
|
memoryview
]:
assert
self
.
aux_buffers
is
not
None
assert
self
.
aux_buffers
is
not
None
# view the tensor as a contiguous 1D array of bytes
# view the tensor as a contiguous 1D array of bytes
arr
=
obj
.
flatten
().
contiguous
().
view
(
torch
.
uint8
).
numpy
(
)
arr
_data
=
tensor_data
(
obj
)
if
obj
.
nbytes
<
self
.
size_threshold
:
if
obj
.
nbytes
<
self
.
size_threshold
:
# Smaller tensors are encoded inline, just like ndarrays.
# Smaller tensors are encoded inline, just like ndarrays.
data
=
msgpack
.
Ext
(
CUSTOM_TYPE_RAW_VIEW
,
arr
.
data
)
data
=
msgpack
.
Ext
(
CUSTOM_TYPE_RAW_VIEW
,
arr
_
data
)
else
:
else
:
# Otherwise encode index of backing buffer to avoid copy.
# Otherwise encode index of backing buffer to avoid copy.
data
=
len
(
self
.
aux_buffers
)
data
=
len
(
self
.
aux_buffers
)
self
.
aux_buffers
.
append
(
arr
.
data
)
self
.
aux_buffers
.
append
(
arr
_
data
)
dtype
=
str
(
obj
.
dtype
).
removeprefix
(
"torch."
)
dtype
=
str
(
obj
.
dtype
).
removeprefix
(
"torch."
)
return
dtype
,
obj
.
shape
,
data
return
dtype
,
obj
.
shape
,
data
...
...
vllm/v1/utils.py
View file @
ff93cc8c
...
@@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
...
@@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
_PROFILER_FUNC
=
func
_PROFILER_FUNC
=
func
return
func
(
name
)
return
func
(
name
)
def tensor_data(tensor: torch.Tensor) -> memoryview:
    """Return a tensor's data as a flat uint8 memoryview.

    Useful wherever the raw bytes are needed, such as serializing or hashing.

    Args:
        tensor: The input tensor.

    Returns:
        A memoryview over the flattened, contiguous tensor data viewed as uint8.
    """
    # Flatten to one contiguous dimension so the data can be reinterpreted
    # byte by byte.
    flat = tensor.flatten().contiguous()
    as_uint8 = flat.view(torch.uint8)
    return as_uint8.numpy().data
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment