"vscode:/vscode.git/clone" did not exist on "ccd3e55e51d44bf3a17b2203a304c9609aa5dfe2"
Unverified Commit bf8717eb authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[V1] Prefix caching for vision language models (#11187)


Signed-off-by: default avatarCody Yu <hao.yu.cody@gmail.com>
parent c77eb8a3
...@@ -2,16 +2,23 @@ ...@@ -2,16 +2,23 @@
import pytest import pytest
from vllm.inputs import token_inputs from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
def make_request(request_id, prompt_token_ids): def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
return Request( return Request(
request_id=request_id, request_id=request_id,
inputs=token_inputs(prompt_token_ids=prompt_token_ids), inputs=token_inputs(prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions}
if mm_positions else None,
multi_modal_hashes=mm_hashes),
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100, eos_token_id=100,
arrival_time=0, arrival_time=0,
...@@ -38,6 +45,7 @@ def test_prefill(): ...@@ -38,6 +45,7 @@ def test_prefill():
all_token_ids = common_token_ids + unique_token_ids all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids) req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0) computed_blocks = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3
assert not computed_blocks assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
...@@ -61,6 +69,7 @@ def test_prefill(): ...@@ -61,6 +69,7 @@ def test_prefill():
unique_token_ids = [3] * 5 unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids) req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1) computed_blocks = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2] assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
...@@ -90,6 +99,7 @@ def test_prefill(): ...@@ -90,6 +99,7 @@ def test_prefill():
unique_token_ids = [3] * 6 unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids) req2 = make_request("2", common_token_ids + unique_token_ids)
computed_block = manager.get_computed_blocks(req2) computed_block = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_block] == [0, 1, 2] assert [b.block_id for b in computed_block] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
...@@ -416,3 +426,77 @@ def test_cache_blocks(): ...@@ -416,3 +426,77 @@ def test_cache_blocks():
) )
assert len(manager.cached_block_hash_to_block) == 3 assert len(manager.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None assert blocks[0].block_hash is not None
def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
"""
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
# [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
common_token_ids = list(range(10)) + [-1] * 6
common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
common_token_ids += [-1] * 16
common_mm_positions = [
PlaceholderRange(offset=11, length=10),
PlaceholderRange(offset=30, length=18),
]
common_mm_hashes = ["aaa", "bbb"]
# A unique image plus some text tokens.
unique_token_ids = [-1] * 7 + [100] * 4
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req0 = make_request("0",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
assert not computed_blocks
assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), )
assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0))
assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0
# The just completed block should have hashes with extra keys.
assert len(req0.kv_block_hashes) == 4
assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), )
# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req1 = make_request("1",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
...@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli(): ...@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
assert engine_args.enable_prefix_caching assert engine_args.enable_prefix_caching
def test_defaults():
engine_args = EngineArgs(model="facebook/opt-125m")
# Assert V1 defaults
assert (engine_args.enable_prefix_caching
), "V1 turns on prefix caching by default"
def test_defaults_with_usage_context(): def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m") engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config( vllm_config: VllmConfig = engine_args.create_engine_config(
...@@ -52,10 +44,3 @@ def test_defaults_with_usage_context(): ...@@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
UsageContext.OPENAI_API_SERVER) UsageContext.OPENAI_API_SERVER)
assert vllm_config.scheduler_config.max_num_seqs == 1024 assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 2048 assert vllm_config.scheduler_config.max_num_batched_tokens == 2048
def test_prefix_cache_disabled_with_multimodel():
engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert not vllm_config.cache_config.enable_prefix_caching
...@@ -205,6 +205,7 @@ class EngineArgs: ...@@ -205,6 +205,7 @@ class EngineArgs:
# by user. # by user.
if self.enable_prefix_caching is None: if self.enable_prefix_caching is None:
self.enable_prefix_caching = bool(envs.VLLM_USE_V1) self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
# Override max_num_seqs if it's not set by user. # Override max_num_seqs if it's not set by user.
if self.max_num_seqs is None: if self.max_num_seqs is None:
self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024 self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024
...@@ -1026,11 +1027,11 @@ class EngineArgs: ...@@ -1026,11 +1027,11 @@ class EngineArgs:
device_config = DeviceConfig(device=self.device) device_config = DeviceConfig(device=self.device)
model_config = self.create_model_config() model_config = self.create_model_config()
if model_config.is_multimodal_model: if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
if self.enable_prefix_caching: and self.enable_prefix_caching):
logger.warning( logger.warning("--enable-prefix-caching is currently not "
"--enable-prefix-caching is currently not " "supported for multimodal models in v0 and "
"supported for multimodal models and has been disabled.") "has been disabled.")
self.enable_prefix_caching = False self.enable_prefix_caching = False
cache_config = CacheConfig( cache_config = CacheConfig(
...@@ -1249,11 +1250,14 @@ class EngineArgs: ...@@ -1249,11 +1250,14 @@ class EngineArgs:
# When no user override, set the default values based on the usage # When no user override, set the default values based on the usage
# context. # context.
# TODO(woosuk): Tune the default values for different hardware. # TODO(woosuk): Tune the default values for different hardware.
if self.max_num_batched_tokens is None: default_max_num_batched_tokens = {
if usage_context == UsageContext.LLM_CLASS: UsageContext.LLM_CLASS: 8192,
self.max_num_batched_tokens = 8192 UsageContext.OPENAI_API_SERVER: 2048,
elif usage_context == UsageContext.OPENAI_API_SERVER: }
self.max_num_batched_tokens = 2048 if (self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens):
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context]
logger.warning( logger.warning(
"Setting max_num_batched_tokens to %d for %s usage context.", "Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens, usage_context.value) self.max_num_batched_tokens, usage_context.value)
...@@ -1263,9 +1267,6 @@ class EngineArgs: ...@@ -1263,9 +1267,6 @@ class EngineArgs:
Override the EngineConfig's configs based on the usage context for V1. Override the EngineConfig's configs based on the usage context for V1.
""" """
assert envs.VLLM_USE_V1, "V1 is not enabled" assert envs.VLLM_USE_V1, "V1 is not enabled"
if engine_config.model_config.is_multimodal_model:
# TODO (ywang96): Enable APC by default when VLM supports it.
assert not engine_config.cache_config.enable_prefix_caching
@dataclass @dataclass
......
...@@ -162,6 +162,11 @@ class TokenInputs(TypedDict): ...@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data. Placeholder ranges for the multi-modal data.
""" """
multi_modal_hashes: NotRequired[List[str]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]] mm_processor_kwargs: NotRequired[Dict[str, Any]]
""" """
Optional multi-modal processor kwargs to be forwarded to the Optional multi-modal processor kwargs to be forwarded to the
...@@ -177,6 +182,7 @@ def token_inputs( ...@@ -177,6 +182,7 @@ def token_inputs(
prompt: Optional[str] = None, prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None, multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[List[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs: ) -> TokenInputs:
...@@ -191,6 +197,8 @@ def token_inputs( ...@@ -191,6 +197,8 @@ def token_inputs(
inputs["multi_modal_data"] = multi_modal_data inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None: if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None: if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None: if mm_processor_kwargs is not None:
...@@ -295,6 +303,18 @@ class SingletonInputsAdapter: ...@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
assert_never(inputs) assert_never(inputs)
@cached_property
def multi_modal_hashes(self) -> List[str]:
inputs = self.inputs
if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])
if inputs["type"] == "multimodal":
return inputs.get("mm_hashes", [])
assert_never(inputs)
@cached_property @cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
inputs = self.inputs inputs = self.inputs
......
...@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict): ...@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs: MultiModalKwargs mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
mm_hashes: NotRequired[List[str]]
"""The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict mm_placeholders: MultiModalPlaceholderDict
""" """
For each modality, information about the placeholder tokens in For each modality, information about the placeholder tokens in
......
...@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional ...@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens) hash_request_tokens)
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -83,10 +85,12 @@ class KVCacheManager: ...@@ -83,10 +85,12 @@ class KVCacheManager:
computed_blocks = [] computed_blocks = []
# TODO(rickyx): potentially we could cache this so we don't have to # The block hashes for the request may already be computed
# recompute it every time. # if the request was preempted and resumed.
block_hashes = hash_request_tokens(self.block_size, if not request.kv_block_hashes:
request.all_token_ids) request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes
for block_hash in block_hashes: for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not # block_hashes is a chain of block hashes. If a block hash is not
...@@ -242,12 +246,14 @@ class KVCacheManager: ...@@ -242,12 +246,14 @@ class KVCacheManager:
num_computed_tokens = len(computed_blocks) * self.block_size num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks:
self._cache_full_blocks( self._cache_full_blocks(
request=request, request=request,
blk_start_idx=len(computed_blocks), blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed. # The new full blocks are the full blocks that are not computed.
full_blocks=self.req_to_blocks[request.request_id] full_blocks=new_full_blocks,
[len(computed_blocks):num_full_blocks],
prev_block=computed_blocks[-1] if computed_blocks else None, prev_block=computed_blocks[-1] if computed_blocks else None,
) )
...@@ -376,6 +382,8 @@ class KVCacheManager: ...@@ -376,6 +382,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata. full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain. prev_block: The previous block in the chain.
""" """
num_cached_block_hashes = len(request.kv_block_hashes)
# Update the new blocks with the block hashes through the chain. # Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None prev_block_hash_value = None
if prev_block is not None: if prev_block is not None:
...@@ -387,17 +395,35 @@ class KVCacheManager: ...@@ -387,17 +395,35 @@ class KVCacheManager:
for i, blk in enumerate(full_blocks): for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i blk_idx = blk_start_idx + i
block_tokens = request.all_token_ids[blk_idx * if blk_idx < num_cached_block_hashes:
self.block_size:(blk_idx + # The block hash may already be computed in
1) * # "get_computed_blocks" if the tokens are not generated by
self.block_size] # this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = request.kv_block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == self.block_size, ( assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got {len(block_tokens)} " f"Expected {self.block_size} tokens, got "
f"at {blk_idx}th block for request " f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})") f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block. # Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens) block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)
# Update and added the full block to the cache. # Update and added the full block to the cache.
blk.block_hash = block_hash blk.block_hash = block_hash
......
"""KV-Cache Utilities.""" """KV-Cache Utilities."""
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Tuple from typing import Any, List, NamedTuple, Optional, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
class BlockHashType(NamedTuple): class BlockHashType(NamedTuple):
"""Hash value of a block and the token IDs in the block. """Hash value of a block (int), the token IDs in the block, and extra keys.
The reason we keep a tuple of token IDs is to make sure no hash The reason we keep a tuple of token IDs and extra keys is to make sure
collision happens when the hash value is the same. no hash collision happens when the hash value is the same.
""" """
# Hash value of the block in an integer.
hash_value: int hash_value: int
# Token IDs in the block.
token_ids: Tuple[int, ...] token_ids: Tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None
@dataclass @dataclass
...@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue: ...@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
return ret return ret
def hash_block_tokens(parent_block_hash: Optional[int], def generate_block_hash_extra_keys(
curr_block_token_ids: Sequence[int]) -> BlockHashType: request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
return None, start_mm_idx
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set mm_cache_preprocessor=True.")
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][
"length"] < start_token_idx:
return None, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if start_mm_idx < 0:
assert -start_mm_idx <= len(mm_positions)
start_mm_idx = len(mm_positions) + start_mm_idx
extra_keys = []
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx]["offset"]
length = mm_positions[curr_mm_idx]["length"]
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
curr_mm_idx += 1
continue
# The block contains the current mm input.
mm_start = max(0, start_token_idx - offset)
extra_keys.append((mm_hashes[curr_mm_idx], mm_start))
if end_token_idx >= offset + length:
# If this block contains the end of the current mm input,
# move to the next mm input as this block may also contain
# the next mm input.
curr_mm_idx += 1
else:
# Otherwise this block is done with mm inputs.
break
else:
# This block has not reached the current mm input.
break
return tuple(extra_keys), curr_mm_idx
def hash_block_tokens(
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and """Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing prefix caching. We use LRU cache for this function to avoid recomputing
...@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int], ...@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
if this is the first block. if this is the first block.
curr_block_token_ids: A list of token ids in the current curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full. block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns: Returns:
The hash value of the block and the token ids in the block. The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block. The entire tuple is used as the hash key of the block.
""" """
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)), return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
tuple(curr_block_token_ids)) tuple(curr_block_token_ids), extra_keys)
def hash_request_tokens(block_size: int, def hash_request_tokens(block_size: int,
token_ids: Sequence[int]) -> List[BlockHashType]: request: Request) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of """Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching. token IDs. The hash value is used for prefix caching.
Args: Args:
block_size: The size of each block. block_size: The size of each block.
token_ids: A sequence of token ids in the request. request: The request object.
Returns: Returns:
The list of computed hash values. The list of computed hash values.
""" """
token_ids = request.all_token_ids
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")
# TODO: Extend this to support other features such as LoRA.
need_extra_keys = bool(mm_positions)
extra_keys = None
curr_mm_idx = 0
ret = [] ret = []
parent_block_hash_value = None parent_block_hash_value = None
for start in range(0, len(token_ids), block_size): for start in range(0, len(token_ids), block_size):
...@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int, ...@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full. # Do not hash the block if it is not full.
if len(block_token_ids) < block_size: if len(block_token_ids) < block_size:
break break
# Add extra keys if the block is a multi-modal block.
if need_extra_keys:
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(parent_block_hash_value, block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids) block_token_ids, extra_keys)
ret.append(block_hash) ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value parent_block_hash_value = block_hash.hash_value
return ret return ret
...@@ -516,6 +516,7 @@ class NewRequestData: ...@@ -516,6 +516,7 @@ class NewRequestData:
prompt_token_ids: List[int] prompt_token_ids: List[int]
prompt: Optional[str] prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"] mm_inputs: List["MultiModalKwargs"]
mm_hashes: List[str]
mm_positions: List["PlaceholderRange"] mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams sampling_params: SamplingParams
block_ids: List[int] block_ids: List[int]
...@@ -533,6 +534,7 @@ class NewRequestData: ...@@ -533,6 +534,7 @@ class NewRequestData:
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt, prompt=request.prompt,
mm_inputs=request.mm_inputs, mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions, mm_positions=request.mm_positions,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
block_ids=block_ids, block_ids=block_ids,
......
...@@ -60,9 +60,13 @@ class AsyncLLM(EngineClient): ...@@ -60,9 +60,13 @@ class AsyncLLM(EngineClient):
self.client_aborted_requests: List[str] = [] self.client_aborted_requests: List[str] = []
# Processor (converts Inputs --> EngineCoreRequests). # Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(vllm_config.model_config, self.processor = Processor(
vllm_config.lora_config, self.tokenizer, model_config=vllm_config.model_config,
input_registry) cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput). # Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self.detokenizer = Detokenizer( self.detokenizer = Detokenizer(
......
...@@ -65,7 +65,8 @@ class EngineCore: ...@@ -65,7 +65,8 @@ class EngineCore:
self._last_logging_time = time.time() self._last_logging_time = time.time()
self.mm_input_mapper_server = MMInputMapperServer() self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)
def _initialize_kv_caches(self, def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]: cache_config: CacheConfig) -> Tuple[int, int]:
...@@ -98,9 +99,8 @@ class EngineCore: ...@@ -98,9 +99,8 @@ class EngineCore:
# MM mapper, so anything that has a hash must have a HIT cache # MM mapper, so anything that has a hash must have a HIT cache
# entry here as well. # entry here as well.
assert request.mm_inputs is not None assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = ( request.mm_inputs = self.mm_input_mapper_server.process_inputs(
self.mm_input_mapper_server.process_inputs( request.mm_inputs, request.mm_hashes)
request.mm_inputs, request.mm_hashes))
req = Request.from_engine_core_request(request) req = Request.from_engine_core_request(request)
......
...@@ -55,9 +55,12 @@ class LLMEngine: ...@@ -55,9 +55,12 @@ class LLMEngine:
self.tokenizer.ping() self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests) # Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config.model_config, self.processor = Processor(model_config=vllm_config.model_config,
vllm_config.lora_config, self.tokenizer, cache_config=vllm_config.cache_config,
input_registry, mm_registry) lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput) # Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer( self.detokenizer = Detokenizer(
......
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
import PIL import PIL
from blake3 import blake3 from blake3 import blake3
...@@ -42,6 +42,8 @@ class MMInputMapperClient: ...@@ -42,6 +42,8 @@ class MMInputMapperClient:
model_config) model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config) self.mm_registry.init_mm_limits_per_prompt(model_config)
# Init cache
self.use_cache = model_config.mm_cache_preprocessor
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# DEBUG: Set to None to disable # DEBUG: Set to None to disable
...@@ -61,7 +63,7 @@ class MMInputMapperClient: ...@@ -61,7 +63,7 @@ class MMInputMapperClient:
mm_hashes: Optional[List[str]], mm_hashes: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]], mm_processor_kwargs: Optional[Dict[str, Any]],
precomputed_mm_inputs: Optional[List[MultiModalKwargs]], precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]: ) -> List[MultiModalKwargs]:
if precomputed_mm_inputs is None: if precomputed_mm_inputs is None:
image_inputs = mm_data["image"] image_inputs = mm_data["image"]
if not isinstance(image_inputs, list): if not isinstance(image_inputs, list):
...@@ -70,26 +72,21 @@ class MMInputMapperClient: ...@@ -70,26 +72,21 @@ class MMInputMapperClient:
else: else:
num_inputs = len(precomputed_mm_inputs) num_inputs = len(precomputed_mm_inputs)
# Check if hash is enabled # Sanity
use_hash = mm_hashes is not None if self.use_cache:
if use_hash:
assert mm_hashes is not None assert mm_hashes is not None
assert num_inputs == len( assert num_inputs == len(mm_hashes)
mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
num_inputs, len(mm_hashes))
# Process each image input separately, so that later we can schedule # Process each image input separately, so that later we can schedule
# them in a fine-grained manner. # them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided) # Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes: Optional[List[str]] = [] if use_hash else None
ret_inputs: List[MultiModalKwargs] = [] ret_inputs: List[MultiModalKwargs] = []
for input_id in range(num_inputs): for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None: if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
mm_hash = None
mm_input = None mm_input = None
if use_hash: if self.use_cache:
assert mm_hashes is not None assert mm_hashes is not None
mm_hash = mm_hashes[input_id] mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash) mm_input = self.mm_cache.get(mm_hash)
...@@ -106,7 +103,7 @@ class MMInputMapperClient: ...@@ -106,7 +103,7 @@ class MMInputMapperClient:
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
) )
if use_hash: if self.use_cache:
# Add to cache # Add to cache
assert mm_hash is not None assert mm_hash is not None
self.mm_cache.put(mm_hash, mm_input) self.mm_cache.put(mm_hash, mm_input)
...@@ -114,18 +111,15 @@ class MMInputMapperClient: ...@@ -114,18 +111,15 @@ class MMInputMapperClient:
self.mm_cache_hits += 1 self.mm_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server mm_input = None # Avoids sending mm_input to Server
if use_hash:
assert mm_hash is not None
assert ret_hashes is not None
ret_hashes.append(mm_hash)
ret_inputs.append(mm_input) ret_inputs.append(mm_input)
return ret_inputs, ret_hashes return ret_inputs
class MMInputMapperServer: class MMInputMapperServer:
def __init__(self, ): def __init__(self, model_config):
self.use_cache = model_config.mm_cache_preprocessor
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def process_inputs( def process_inputs(
...@@ -135,6 +129,9 @@ class MMInputMapperServer: ...@@ -135,6 +129,9 @@ class MMInputMapperServer:
) -> List[MultiModalKwargs]: ) -> List[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes) assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
return mm_inputs
full_mm_inputs = [] full_mm_inputs = []
for mm_input, mm_hash in zip(mm_inputs, mm_hashes): for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None assert mm_hash is not None
......
import time import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union from typing import Any, Dict, Mapping, Optional, Tuple, Union
from vllm.config import LoRAConfig, ModelConfig from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.parse import is_encoder_decoder_inputs
...@@ -23,6 +23,7 @@ class Processor: ...@@ -23,6 +23,7 @@ class Processor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
tokenizer: BaseTokenizerGroup, tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
...@@ -45,8 +46,9 @@ class Processor: ...@@ -45,8 +46,9 @@ class Processor:
self.mm_input_mapper_client = MMInputMapperClient(model_config) self.mm_input_mapper_client = MMInputMapperClient(model_config)
# Multi-modal hasher (for images) # Multi-modal hasher (for images)
self.mm_hasher = MMHasher( self.use_hash = model_config.mm_cache_preprocessor or \
) if model_config.mm_cache_preprocessor else None cache_config.enable_prefix_caching
self.mm_hasher = MMHasher()
# TODO: run in an ThreadpoolExecutor or BackgroundProcess. # TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# This ideally should releases the GIL, so we should not block the # This ideally should releases the GIL, so we should not block the
...@@ -77,7 +79,7 @@ class Processor: ...@@ -77,7 +79,7 @@ class Processor:
# Compute MM hashes (if enabled) # Compute MM hashes (if enabled)
mm_hashes = None mm_hashes = None
if self.mm_hasher is not None: if self.use_hash:
mm_hashes = self.mm_hasher.hash(prompt) mm_hashes = self.mm_hasher.hash(prompt)
# Process inputs. # Process inputs.
...@@ -118,7 +120,7 @@ class Processor: ...@@ -118,7 +120,7 @@ class Processor:
# Apply MM mapper # Apply MM mapper
mm_inputs = None mm_inputs = None
if len(decoder_inputs.multi_modal_data) > 0: if len(decoder_inputs.multi_modal_data) > 0:
mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs( mm_inputs = self.mm_input_mapper_client.process_inputs(
decoder_inputs.multi_modal_data, mm_hashes, decoder_inputs.multi_modal_data, mm_hashes,
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs) decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
......
import enum import enum
from typing import List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics ...@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_utils import BlockHashType
class Request: class Request:
...@@ -45,6 +48,7 @@ class Request: ...@@ -45,6 +48,7 @@ class Request:
self._all_token_ids: List[int] = self.prompt_token_ids.copy() self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0 self.num_computed_tokens = 0
# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders mm_positions = self.inputs.multi_modal_placeholders
if mm_positions: if mm_positions:
# FIXME(woosuk): Support other modalities. # FIXME(woosuk): Support other modalities.
...@@ -56,6 +60,12 @@ class Request: ...@@ -56,6 +60,12 @@ class Request:
if self.inputs.multi_modal_inputs: if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs self.mm_inputs = self.inputs.multi_modal_inputs
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self._kv_block_hashes: List[BlockHashType] = []
@classmethod @classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls( return cls(
...@@ -65,6 +75,7 @@ class Request: ...@@ -65,6 +75,7 @@ class Request:
prompt=request.prompt, prompt=request.prompt,
multi_modal_data=None, multi_modal_data=None,
multi_modal_inputs=request.mm_inputs, multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders, multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=None, mm_processor_kwargs=None,
), ),
...@@ -121,6 +132,17 @@ class Request: ...@@ -121,6 +132,17 @@ class Request:
num_tokens = self.mm_positions[input_id]["length"] num_tokens = self.mm_positions[input_id]["length"]
return num_tokens return num_tokens
@property
def kv_block_hashes(self) -> ConstantList["BlockHashType"]:
# Prevent directly appending to the kv_block_hashes.
return ConstantList(self._kv_block_hashes)
def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
self._kv_block_hashes = value
def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
self._kv_block_hashes.append(block_hash)
class RequestStatus(enum.IntEnum): class RequestStatus(enum.IntEnum):
"""Status of a request.""" """Status of a request."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment