Unverified Commit ff93cc8c authored by Andrew Sansom's avatar Andrew Sansom Committed by GitHub
Browse files

[CORE] Support Prefix Caching with Prompt Embeds (#27219)


Signed-off-by: default avatarAndrew Sansom <andrew@protopia.ai>
parent 243ed7d3
...@@ -52,7 +52,7 @@ th:not(:first-child) { ...@@ -52,7 +52,7 @@ th:not(:first-child) {
| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/pull/4194)<sup>^</sup> | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | | | [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/pull/4194)<sup>^</sup> | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
| best-of | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ✅ | ✅ | | | | best-of | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ✅ | ✅ | | |
| beam-search | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | | | beam-search | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | [](https://github.com/vllm-project/vllm/issues/25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | | [prompt-embeds](prompt_embeds.md) | ✅ | | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\* Chunked prefill and prefix caching are only applicable to last-token pooling. \* Chunked prefill and prefix caching are only applicable to last-token pooling.
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models. <sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
...@@ -75,4 +75,4 @@ th:not(:first-child) { ...@@ -75,4 +75,4 @@ th:not(:first-child) {
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ❌ | ✅ | | multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ❌ | ✅ |
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [](https://github.com/vllm-project/vllm/issues/25097) | ✅ | | [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [](https://github.com/vllm-project/vllm/issues/25097) | ✅ |
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib import importlib
from collections.abc import Callable from collections.abc import Callable
from typing import Any
import pytest import pytest
import torch import torch
...@@ -32,6 +33,7 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -32,6 +33,7 @@ from vllm.v1.core.kv_cache_utils import (
init_none_hash, init_none_hash,
is_kv_cache_spec_uniform, is_kv_cache_spec_uniform,
make_block_hash_with_group_id, make_block_hash_with_group_id,
tensor_data,
) )
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
...@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request): ...@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):
def make_request( def make_request(
request_id: str, request_id: str,
prompt_token_ids: list[int], prompt_token_ids: list[int] | None,
block_size: int = 3, block_size: int = 3,
hash_fn: Callable = hash, hash_fn: Callable = hash,
mm_positions: list[PlaceholderRange] | None = None, mm_positions: list[PlaceholderRange] | None = None,
mm_hashes: list[str] | None = None, mm_hashes: list[str] | None = None,
cache_salt: str | None = None, cache_salt: str | None = None,
prompt_embeds: torch.Tensor | None = None,
): ):
mm_features = [] mm_features = []
if mm_positions is not None: if mm_positions is not None:
...@@ -90,6 +93,7 @@ def make_request( ...@@ -90,6 +93,7 @@ def make_request(
lora_request=None, lora_request=None,
cache_salt=cache_salt, cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn), block_hasher=get_request_block_hasher(block_size, hash_fn),
prompt_embeds=prompt_embeds,
) )
...@@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt(): ...@@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1 assert next_mm_idx == 1
def test_generate_block_hash_extra_keys_prompt_embeds():
prompt_embeds = torch.randn(10, 3)
request = make_request(
request_id="0",
prompt_token_ids=None,
mm_positions=None,
mm_hashes=None,
prompt_embeds=prompt_embeds,
)
# Test with prompt embeds for the first block
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
expected_embeds = prompt_embeds[0:5]
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
assert extra_keys == (expected_bytes,)
# Test with prompt embeds for the second block
extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
expected_embeds = prompt_embeds[5:10]
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
assert extra_keys == (expected_bytes,)
def test_generate_block_hash_extra_keys_different_prompt_embeds():
prompt_embeds1 = torch.randn(10, 3)
prompt_embeds2 = torch.randn(10, 3)
request1 = make_request(
request_id="0",
prompt_token_ids=None,
mm_positions=None,
mm_hashes=None,
prompt_embeds=prompt_embeds1,
)
request2 = make_request(
request_id="1",
prompt_token_ids=None,
mm_positions=None,
mm_hashes=None,
prompt_embeds=prompt_embeds2,
)
extra_keys1, _ = generate_block_hash_extra_keys(request1, 0, 5, 0)
extra_keys2, _ = generate_block_hash_extra_keys(request2, 0, 5, 0)
assert extra_keys1 != extra_keys2
def test_generate_block_hash_extra_keys_lora(): def test_generate_block_hash_extra_keys_lora():
request = make_request( request = make_request(
request_id="0", request_id="0",
...@@ -1556,3 +1606,88 @@ def test_merge_mla_spec(): ...@@ -1556,3 +1606,88 @@ def test_merge_mla_spec():
] ]
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs) kv_cache_specs[0].merge(kv_cache_specs)
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
block_size = 3
num_tokens = 2 * block_size
prompt_token_ids = [_ for _ in range(num_tokens)]
hidden_size = 5
prompt_embeds = torch.randn((num_tokens, hidden_size))
request = make_request(
request_id="0",
prompt_token_ids=prompt_token_ids,
block_size=block_size,
hash_fn=hash_fn,
prompt_embeds=prompt_embeds,
)
block_hashes = request.block_hashes
assert len(block_hashes) == 2
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
expected_hash1 = hash_fn(
(
kv_cache_utils.NONE_HASH,
tuple(prompt_token_ids[:block_size]),
(block1_embeds_bytes,),
)
)
assert block_hashes[0] == expected_hash1
block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
expected_hash2 = hash_fn(
(
block_hashes[0],
tuple(prompt_token_ids[block_size:num_tokens]),
(block2_embeds_bytes,),
)
)
assert block_hashes[1] == expected_hash2
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]):
block_size = 3
num_tokens = 2 * block_size
prompt_token_ids = [_ for _ in range(num_tokens)]
hidden_size = 5
prompt_embeds = torch.randn((num_tokens, hidden_size))
request = make_request(
request_id="0",
prompt_token_ids=prompt_token_ids,
block_size=block_size,
hash_fn=hash_fn,
mm_positions=[
PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3),
],
mm_hashes=["hash1", "hash2"],
prompt_embeds=prompt_embeds,
)
block_hashes = request.block_hashes
assert len(block_hashes) == 2
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
expected_hash1 = hash_fn(
(
kv_cache_utils.NONE_HASH,
tuple(prompt_token_ids[:block_size]),
("hash1", block1_embeds_bytes),
)
)
assert block_hashes[0] == expected_hash1
block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
expected_hash2 = hash_fn(
(
block_hashes[0],
tuple(prompt_token_ids[block_size:num_tokens]),
("hash2", block2_embeds_bytes),
)
)
assert block_hashes[1] == expected_hash2
...@@ -1743,16 +1743,6 @@ class EngineArgs: ...@@ -1743,16 +1743,6 @@ class EngineArgs:
if model_config.runner_type != "pooling": if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True self.enable_chunked_prefill = True
# TODO: When prefix caching supports prompt embeds inputs, this
# check can be removed.
if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V1. Prefix caching has "
"been disabled."
)
self.enable_prefix_caching = False
if self.enable_prefix_caching is None: if self.enable_prefix_caching is None:
# Disable prefix caching default for hybrid models # Disable prefix caching default for hybrid models
# since the feature is still experimental. # since the feature is still experimental.
......
...@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import (
UniformTypeKVCacheSpecs, UniformTypeKVCacheSpecs,
) )
from vllm.v1.request import Request from vllm.v1.request import Request
from vllm.v1.utils import tensor_data
# BlockHash represents the hash of a single KV-cache block used for # BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from `bytes` helps # prefix caching. Treating it as a distinct type from `bytes` helps
...@@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]: ...@@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
return [request.lora_request.lora_name] return [request.lora_request.lora_name]
def _gen_prompt_embeds_extra_hash_keys(
request: Request, start_token_idx: int, end_token_idx: int
) -> list[bytes]:
"""Generate extra keys related to prompt embeds for block hash computation.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
Returns:
Return prompt embeddings data of the request if it has prompt embeds.
Return empty list otherwise.
"""
if request.prompt_embeds is None:
return []
block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
return [embeds_bytes]
def generate_block_hash_extra_keys( def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int
) -> tuple[tuple[Any, ...] | None, int]: ) -> tuple[tuple[Any, ...] | None, int]:
"""Generate extra keys for the block hash. The extra keys can come from """Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA name). the multi-modal inputs, request specific metadata (e.g., LoRA names), and
data from prompt embeddings.
Args: Args:
request: The request object. request: The request object.
...@@ -484,8 +507,13 @@ def generate_block_hash_extra_keys( ...@@ -484,8 +507,13 @@ def generate_block_hash_extra_keys(
cache_salt_keys: list[str] = ( cache_salt_keys: list[str] = (
[request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else [] [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else []
) )
prompt_embeds_keys = _gen_prompt_embeds_extra_hash_keys(
request, start_token_idx, end_token_idx
)
extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys extra_keys: list[Any] = (
lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys
)
if not extra_keys: if not extra_keys:
return None, new_start_mm_idx return None, new_start_mm_idx
......
...@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import ( ...@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
NestedTensors, NestedTensors,
) )
from vllm.v1.engine import UtilityResult from vllm.v1.engine import UtilityResult
from vllm.v1.utils import tensor_data
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -218,14 +219,14 @@ class MsgpackEncoder: ...@@ -218,14 +219,14 @@ class MsgpackEncoder:
) -> tuple[str, tuple[int, ...], int | memoryview]: ) -> tuple[str, tuple[int, ...], int | memoryview]:
assert self.aux_buffers is not None assert self.aux_buffers is not None
# view the tensor as a contiguous 1D array of bytes # view the tensor as a contiguous 1D array of bytes
arr = obj.flatten().contiguous().view(torch.uint8).numpy() arr_data = tensor_data(obj)
if obj.nbytes < self.size_threshold: if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays. # Smaller tensors are encoded inline, just like ndarrays.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
else: else:
# Otherwise encode index of backing buffer to avoid copy. # Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers) data = len(self.aux_buffers)
self.aux_buffers.append(arr.data) self.aux_buffers.append(arr_data)
dtype = str(obj.dtype).removeprefix("torch.") dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data return dtype, obj.shape, data
......
...@@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: ...@@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
_PROFILER_FUNC = func _PROFILER_FUNC = func
return func(name) return func(name)
def tensor_data(tensor: torch.Tensor) -> memoryview:
"""Get the raw data of a tensor as a uint8 memoryview, useful for
serializing and hashing.
Args:
tensor: The input tensor.
Returns:
A memoryview of the tensor data as uint8.
"""
return tensor.flatten().contiguous().view(torch.uint8).numpy().data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment