Unverified Commit bf8717eb authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[V1] Prefix caching for vision language models (#11187)


Signed-off-by: default avatarCody Yu <hao.yu.cody@gmail.com>
parent c77eb8a3
......@@ -2,16 +2,23 @@
import pytest
from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
def make_request(request_id, prompt_token_ids):
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
return Request(
request_id=request_id,
inputs=token_inputs(prompt_token_ids=prompt_token_ids),
inputs=token_inputs(prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions}
if mm_positions else None,
multi_modal_hashes=mm_hashes),
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
......@@ -38,6 +45,7 @@ def test_prefill():
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
......@@ -61,6 +69,7 @@ def test_prefill():
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
......@@ -90,6 +99,7 @@ def test_prefill():
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_block = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_block] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
......@@ -416,3 +426,77 @@ def test_cache_blocks():
)
assert len(manager.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None
def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
"""
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
# [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
common_token_ids = list(range(10)) + [-1] * 6
common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
common_token_ids += [-1] * 16
common_mm_positions = [
PlaceholderRange(offset=11, length=10),
PlaceholderRange(offset=30, length=18),
]
common_mm_hashes = ["aaa", "bbb"]
# A unique image plus some text tokens.
unique_token_ids = [-1] * 7 + [100] * 4
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req0 = make_request("0",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
assert not computed_blocks
assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), )
assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0))
assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0
# The just completed block should have hashes with extra keys.
assert len(req0.kv_block_hashes) == 4
assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), )
# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req1 = make_request("1",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
......@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
assert engine_args.enable_prefix_caching
def test_defaults():
engine_args = EngineArgs(model="facebook/opt-125m")
# Assert V1 defaults
assert (engine_args.enable_prefix_caching
), "V1 turns on prefix caching by default"
def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(
......@@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
UsageContext.OPENAI_API_SERVER)
assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 2048
def test_prefix_cache_disabled_with_multimodel():
engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert not vllm_config.cache_config.enable_prefix_caching
......@@ -205,6 +205,7 @@ class EngineArgs:
# by user.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
# Override max_num_seqs if it's not set by user.
if self.max_num_seqs is None:
self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024
......@@ -1026,11 +1027,11 @@ class EngineArgs:
device_config = DeviceConfig(device=self.device)
model_config = self.create_model_config()
if model_config.is_multimodal_model:
if self.enable_prefix_caching:
logger.warning(
"--enable-prefix-caching is currently not "
"supported for multimodal models and has been disabled.")
if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
and self.enable_prefix_caching):
logger.warning("--enable-prefix-caching is currently not "
"supported for multimodal models in v0 and "
"has been disabled.")
self.enable_prefix_caching = False
cache_config = CacheConfig(
......@@ -1249,11 +1250,14 @@ class EngineArgs:
# When no user override, set the default values based on the usage
# context.
# TODO(woosuk): Tune the default values for different hardware.
if self.max_num_batched_tokens is None:
if usage_context == UsageContext.LLM_CLASS:
self.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
self.max_num_batched_tokens = 2048
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 8192,
UsageContext.OPENAI_API_SERVER: 2048,
}
if (self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens):
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context]
logger.warning(
"Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens, usage_context.value)
......@@ -1263,9 +1267,6 @@ class EngineArgs:
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
if engine_config.model_config.is_multimodal_model:
# TODO (ywang96): Enable APC by default when VLM supports it.
assert not engine_config.cache_config.enable_prefix_caching
@dataclass
......
......@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data.
"""
multi_modal_hashes: NotRequired[List[str]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
......@@ -177,6 +182,7 @@ def token_inputs(
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[List[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
......@@ -191,6 +197,8 @@ def token_inputs(
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None:
......@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
assert_never(inputs)
@cached_property
def multi_modal_hashes(self) -> List[str]:
inputs = self.inputs
if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])
if inputs["type"] == "multimodal":
return inputs.get("mm_hashes", [])
assert_never(inputs)
@cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
inputs = self.inputs
......
......@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: NotRequired[List[str]]
"""The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict
"""
For each modality, information about the placeholder tokens in
......
......@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens,
KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request
......@@ -83,10 +85,12 @@ class KVCacheManager:
computed_blocks = []
# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
# The block hashes for the request may already be computed
# if the request was preempted and resumed.
if not request.kv_block_hashes:
request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
......@@ -242,14 +246,16 @@ class KVCacheManager:
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=self.req_to_blocks[request.request_id]
[len(computed_blocks):num_full_blocks],
prev_block=computed_blocks[-1] if computed_blocks else None,
)
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
)
return new_blocks
......@@ -376,6 +382,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
num_cached_block_hashes = len(request.kv_block_hashes)
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
if prev_block is not None:
......@@ -387,17 +395,35 @@ class KVCacheManager:
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i
block_tokens = request.all_token_ids[blk_idx *
self.block_size:(blk_idx +
1) *
self.block_size]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got {len(block_tokens)} "
f"at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = request.kv_block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)
# Update and added the full block to the cache.
blk.block_hash = block_hash
......
"""KV-Cache Utilities."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Tuple
from typing import Any, List, NamedTuple, Optional, Tuple
from vllm.logger import init_logger
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockHashType(NamedTuple):
"""Hash value of a block and the token IDs in the block.
The reason we keep a tuple of token IDs is to make sure no hash
collision happens when the hash value is the same.
"""Hash value of a block (int), the token IDs in the block, and extra keys.
The reason we keep a tuple of token IDs and extra keys is to make sure
no hash collision happens when the hash value is the same.
"""
# Hash value of the block in an integer.
hash_value: int
# Token IDs in the block.
token_ids: Tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None
@dataclass
......@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
return ret
def hash_block_tokens(parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int]) -> BlockHashType:
def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
return None, start_mm_idx
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set mm_cache_preprocessor=True.")
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][
"length"] < start_token_idx:
return None, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if start_mm_idx < 0:
assert -start_mm_idx <= len(mm_positions)
start_mm_idx = len(mm_positions) + start_mm_idx
extra_keys = []
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx]["offset"]
length = mm_positions[curr_mm_idx]["length"]
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
curr_mm_idx += 1
continue
# The block contains the current mm input.
mm_start = max(0, start_token_idx - offset)
extra_keys.append((mm_hashes[curr_mm_idx], mm_start))
if end_token_idx >= offset + length:
# If this block contains the end of the current mm input,
# move to the next mm input as this block may also contain
# the next mm input.
curr_mm_idx += 1
else:
# Otherwise this block is done with mm inputs.
break
else:
# This block has not reached the current mm input.
break
return tuple(extra_keys), curr_mm_idx
def hash_block_tokens(
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
......@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
if this is the first block.
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
tuple(curr_block_token_ids))
tuple(curr_block_token_ids), extra_keys)
def hash_request_tokens(block_size: int,
token_ids: Sequence[int]) -> List[BlockHashType]:
request: Request) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
token_ids: A sequence of token ids in the request.
request: The request object.
Returns:
The list of computed hash values.
"""
token_ids = request.all_token_ids
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")
# TODO: Extend this to support other features such as LoRA.
need_extra_keys = bool(mm_positions)
extra_keys = None
curr_mm_idx = 0
ret = []
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
......@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
# Add extra keys if the block is a multi-modal block.
if need_extra_keys:
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids)
block_token_ids, extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
......@@ -516,6 +516,7 @@ class NewRequestData:
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"]
mm_hashes: List[str]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
block_ids: List[int]
......@@ -533,6 +534,7 @@ class NewRequestData:
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
......
......@@ -60,9 +60,13 @@ class AsyncLLM(EngineClient):
self.client_aborted_requests: List[str] = []
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry)
self.processor = Processor(
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self.detokenizer = Detokenizer(
......
......@@ -65,7 +65,8 @@ class EngineCore:
self._last_logging_time = time.time()
self.mm_input_mapper_server = MMInputMapperServer()
self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)
def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]:
......@@ -98,9 +99,8 @@ class EngineCore:
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = (
self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes))
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
......
......@@ -55,9 +55,12 @@ class LLMEngine:
self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry, mm_registry)
self.processor = Processor(model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer(
......
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import PIL
from blake3 import blake3
......@@ -42,6 +42,8 @@ class MMInputMapperClient:
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)
# Init cache
self.use_cache = model_config.mm_cache_preprocessor
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# DEBUG: Set to None to disable
......@@ -61,7 +63,7 @@ class MMInputMapperClient:
mm_hashes: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]],
precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]:
) -> List[MultiModalKwargs]:
if precomputed_mm_inputs is None:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
......@@ -70,26 +72,21 @@ class MMInputMapperClient:
else:
num_inputs = len(precomputed_mm_inputs)
# Check if hash is enabled
use_hash = mm_hashes is not None
if use_hash:
# Sanity
if self.use_cache:
assert mm_hashes is not None
assert num_inputs == len(
mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
num_inputs, len(mm_hashes))
assert num_inputs == len(mm_hashes)
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes: Optional[List[str]] = [] if use_hash else None
ret_inputs: List[MultiModalKwargs] = []
for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
mm_hash = None
mm_input = None
if use_hash:
if self.use_cache:
assert mm_hashes is not None
mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash)
......@@ -106,7 +103,7 @@ class MMInputMapperClient:
mm_processor_kwargs=mm_processor_kwargs,
)
if use_hash:
if self.use_cache:
# Add to cache
assert mm_hash is not None
self.mm_cache.put(mm_hash, mm_input)
......@@ -114,18 +111,15 @@ class MMInputMapperClient:
self.mm_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server
if use_hash:
assert mm_hash is not None
assert ret_hashes is not None
ret_hashes.append(mm_hash)
ret_inputs.append(mm_input)
return ret_inputs, ret_hashes
return ret_inputs
class MMInputMapperServer:
def __init__(self, ):
def __init__(self, model_config):
self.use_cache = model_config.mm_cache_preprocessor
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def process_inputs(
......@@ -135,6 +129,9 @@ class MMInputMapperServer:
) -> List[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
return mm_inputs
full_mm_inputs = []
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
......
import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from vllm.config import LoRAConfig, ModelConfig
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
......@@ -23,6 +23,7 @@ class Processor:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
......@@ -45,8 +46,9 @@ class Processor:
self.mm_input_mapper_client = MMInputMapperClient(model_config)
# Multi-modal hasher (for images)
self.mm_hasher = MMHasher(
) if model_config.mm_cache_preprocessor else None
self.use_hash = model_config.mm_cache_preprocessor or \
cache_config.enable_prefix_caching
self.mm_hasher = MMHasher()
# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# This ideally should releases the GIL, so we should not block the
......@@ -77,7 +79,7 @@ class Processor:
# Compute MM hashes (if enabled)
mm_hashes = None
if self.mm_hasher is not None:
if self.use_hash:
mm_hashes = self.mm_hasher.hash(prompt)
# Process inputs.
......@@ -118,7 +120,7 @@ class Processor:
# Apply MM mapper
mm_inputs = None
if len(decoder_inputs.multi_modal_data) > 0:
mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs(
mm_inputs = self.mm_input_mapper_client.process_inputs(
decoder_inputs.multi_modal_data, mm_hashes,
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
......
import enum
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
......@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_utils import BlockHashType
class Request:
......@@ -45,6 +48,7 @@ class Request:
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0
# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
# FIXME(woosuk): Support other modalities.
......@@ -56,6 +60,12 @@ class Request:
if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self._kv_block_hashes: List[BlockHashType] = []
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
......@@ -65,6 +75,7 @@ class Request:
prompt=request.prompt,
multi_modal_data=None,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=None,
),
......@@ -121,6 +132,17 @@ class Request:
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens
@property
def kv_block_hashes(self) -> ConstantList["BlockHashType"]:
# Prevent directly appending to the kv_block_hashes.
return ConstantList(self._kv_block_hashes)
def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
self._kv_block_hashes = value
def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
self._kv_block_hashes.append(block_hash)
class RequestStatus(enum.IntEnum):
"""Status of a request."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment