Unverified Commit c280066f authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent b9dc9d26
......@@ -7,6 +7,7 @@ import pytest
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList
from .utils import create_requests, create_scheduler
......@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens,
max_tokens=3,
same_prompt=True)
same_prompt=True,
block_size=BLOCK_SIZE)
requests_copy = requests.copy()
# Two requests with the same prompt.
......@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
block_size=BLOCK_SIZE)
requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens,
max_tokens=num_output_tokens)
max_tokens=num_output_tokens,
block_size=BLOCK_SIZE)
for req in requests:
scheduler.add_request(req)
......@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():
# Create next-turn requests whose prompts are the full output of the
# previous turn.
next_turn_requests = create_requests(
num_requests=5,
num_tokens=num_prompt_tokens + num_output_tokens,
max_tokens=num_output_tokens,
)
next_turn_requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens +
num_output_tokens,
max_tokens=num_output_tokens,
block_size=BLOCK_SIZE)
for i, req in enumerate(next_turn_requests):
req.prompt_token_ids = (requests[i].prompt_token_ids +
list(requests[i].output_token_ids))
req._all_token_ids = req.prompt_token_ids.copy()
req.all_token_ids = ConstantList(req._all_token_ids)
req.block_hashes = []
req.block_hashes = req.get_hash_new_full_blocks()
# Schedule the next-turn requests.
for req in next_turn_requests:
scheduler.add_request(req)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from typing import Optional
from typing import Callable, Optional
import pytest
import torch
......@@ -19,7 +19,7 @@ from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
hash_block_tokens, hash_request_tokens, init_none_hash,
get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
......@@ -33,6 +33,8 @@ from vllm.v1.request import Request
def make_request(
request_id: str,
prompt_token_ids: list[int],
block_size: int = 3,
hash_fn: Callable = hash,
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
......@@ -49,18 +51,17 @@ def make_request(
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
)
return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn))
def new_kv_cache_spec(block_size=16,
......@@ -428,12 +429,14 @@ def test_hash_block_tokens(hash_fn):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_request_tokens(hash_fn):
def test_request_block_hasher(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
request = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=[
PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3),
......@@ -441,9 +444,7 @@ def test_hash_request_tokens(hash_fn):
mm_hashes=["hash1", "hash2"],
)
block_size = 3
block_hashes = hash_request_tokens(hash_fn, block_size, request)
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
......@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
request1 = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=[
PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3),
......@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
],
mm_hashes=["hash3", "hash2"],
)
block_size = 3
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
block_hashes1 = request1.block_hashes
block_hashes2 = request2.block_hashes
assert block_hashes1[0] != block_hashes2[0]
assert block_hashes1[1] != block_hashes2[1]
......@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
request = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=None,
mm_hashes=None,
)
block_size = 3
block_hashes = hash_request_tokens(hash_fn, block_size, request)
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2)
......@@ -858,6 +861,7 @@ def test_allocate_with_lookahead():
request = make_request(
request_id="0",
prompt_token_ids=[],
block_size=block_size,
mm_positions=None,
mm_hashes=None,
)
......
This diff is collapsed.
......@@ -589,7 +589,7 @@ def test_preempt_during_execution():
block_size=16,
num_blocks=11,
enable_prefix_caching=False)
requests = create_requests(num_requests=2, num_tokens=80)
requests = create_requests(num_requests=2, num_tokens=80, block_size=16)
# Schedule the first request.
scheduler.add_request(requests[0])
......@@ -762,7 +762,7 @@ def _assert_right_scheduler_output(
def _assert_right_kv_cache_manager(
scheduler: Scheduler,
req_ids: list[str],
requests: list[Request],
num_tokens: int,
block_size: int,
num_requests: int,
......@@ -772,12 +772,12 @@ def _assert_right_kv_cache_manager(
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req_id in req_ids:
for req in requests:
blocks = (scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[req_id])
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
single_type_managers[0].req_to_blocks[req.request_id])
hashes = req.block_hashes
assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
......@@ -840,7 +840,8 @@ def test_kv_connector_basic():
MAX_TOKENS = 3
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
......@@ -868,7 +869,7 @@ def test_kv_connector_basic():
)
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
_assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE,
NUM_REQUESTS, NUM_TOTAL_BLOCKS)
# Continue Generation until done.
......@@ -886,7 +887,8 @@ def test_kv_connector_basic():
NUM_TOKENS = NUM_TOKENS_PREFIX * 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
......@@ -915,7 +917,7 @@ def test_kv_connector_basic():
NUM_MATCHED_NEW_TOKENS))
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
_assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE,
NUM_REQUESTS, NUM_TOTAL_BLOCKS)
# Continue Generation until done.
......@@ -953,7 +955,8 @@ def test_kv_connector_unable_to_allocate():
MAX_TOKENS = 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
......@@ -1034,7 +1037,8 @@ def test_kv_connector_handles_preemption():
MAX_TOKENS = BLOCK_SIZE * 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
......@@ -1162,7 +1166,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
......
......@@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
def get_sliding_window_manager(sliding_window_spec, block_pool):
return SlidingWindowManager(sliding_window_spec,
block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)
......@@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool):
return ChunkedLocalAttentionManager(chunked_local_attention_spec,
block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)
......
......@@ -10,6 +10,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
......@@ -114,6 +116,9 @@ def create_scheduler(
)
_none_hash_initialized = False
def create_requests(
num_requests: int,
num_tokens: int = 10,
......@@ -122,7 +127,14 @@ def create_requests(
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None,
same_prompt: bool = False,
block_size: int = 16,
) -> list[Request]:
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
......@@ -139,9 +151,11 @@ def create_requests(
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position)
mm_hashes = ["hash"] * len(mm_position)
else:
mm_position = None
mm_kwargs = None
mm_hashes = None
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens)
request = Request(
......@@ -151,8 +165,9 @@ def create_requests(
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
multi_modal_hashes=mm_hashes,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
requests.append(request)
return requests
......@@ -147,6 +147,7 @@ def test_basic_interface():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_id = request.request_id
......@@ -186,6 +187,7 @@ def test_prompt_less_than_block_size():
# Request will have 1 partial remote block.
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
num_remote_blocks=1)
......
......@@ -21,6 +21,7 @@ def test_basic_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
......@@ -103,8 +104,10 @@ def test_short_prompt_lifecycle():
scheduler = create_scheduler(vllm_config)
# Not enough tokens for full block.
NUM_TOKENS = vllm_config.cache_config.block_size // 2
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = BLOCK_SIZE // 2
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
......@@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS)
request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS)
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
......@@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
......
......@@ -23,6 +23,7 @@ def test_basic_lifecycle():
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
......@@ -133,14 +134,17 @@ def test_interleaved_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_local_a = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
)
request_local_b = create_request(
request_id=3,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
)
......@@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching():
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True,
......@@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching():
request_local = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=False,
use_all_1s_for_prompt_tokens=True,
......@@ -292,6 +298,7 @@ def test_full_block_prompt():
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
......@@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True)
......@@ -456,8 +466,11 @@ def test_cannot_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True)
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
from typing import Any, Optional
from typing import Any, Callable, Optional
import torch
......@@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
......@@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
......@@ -115,16 +116,23 @@ def create_scheduler(
)
def create_request(
request_id: int,
num_tokens: int = 10,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
) -> Request:
_none_hash_initialized = False
def create_request(request_id: int,
num_tokens: int = 10,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
hash_fn: Callable = hash) -> Request:
"""Make dummy request for testing."""
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
kv_transfer_params: Optional[dict[str, Any]] = None
......@@ -158,6 +166,7 @@ def create_request(
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn),
)
req.kv_transfer_params = kv_transfer_params
return req
......
......@@ -3243,6 +3243,24 @@ def sha256_cbor_64bit(input) -> int:
return full_hash & ((1 << 64) - 1)
def get_hash_fn_by_name(hash_fn_name: str) -> Callable:
"""Get a hash function by name, or raise an error if
the function is not found.
Args:
hash_fn_name: Name of the hash function.
Returns:
A hash function.
"""
if hash_fn_name == "sha256":
return sha256
if hash_fn_name == "sha256_cbor_64bit":
return sha256_cbor_64bit
if hash_fn_name == "builtin":
return hash
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.
......
......@@ -2,15 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from typing import Callable, Optional
from typing import Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
FreeKVCacheBlockQueue, KVCacheBlock)
from vllm.v1.request import Request
logger = init_logger(__name__)
......@@ -97,84 +95,39 @@ class BlockPool:
self,
request: Request,
blocks: list[KVCacheBlock],
block_hashes: list[BlockHash],
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
kv_cache_group_id: int,
hash_fn: Callable,
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
metadata to be updated and cached. Given a request, it computes the
block hashes for the blocks starting from `num_cached_blocks` to
`num_full_blocks`, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
metadata to be updated and cached. Given a request, it updates the
metadata for each block and caching it in the
`cached_block_hash_to_block`.
The block hashes values are computed by the Request object immediately
when it is created and when new tokens are appended.
Args:
request: The request to cache the blocks.
blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list. In this case the
missed block hash will be computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
hash_fn: The hash function to use for block hashes.
"""
if num_cached_blocks == num_full_blocks:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(block_hashes) >= num_cached_blocks
new_block_hashes = block_hashes[num_cached_blocks:]
assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:]
# Update the new blocks with the block hashes through the chain.
if num_cached_blocks == 0:
prev_block_hash_value = None
else:
prev_block = blocks[num_cached_blocks - 1]
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.get_hash_value()
parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
else None)
for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None
if i < len(new_block_hashes):
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption), or by other
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash = new_block_hashes[i]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
blk_idx = num_cached_blocks + i
start_token_idx = blk_idx * block_size
end_token_idx = (blk_idx + 1) * block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == block_size, (
f"Expected {block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)
block_hash = new_block_hashes[i]
# Update and added the full block to the cache.
block_hash_with_group_id = BlockHashWithGroupId(
......@@ -184,9 +137,15 @@ class BlockPool:
blk.block_id] = blk
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
prev_block_hash_value = block_hash.hash_value
if self.enable_kv_cache_events:
if num_cached_blocks == 0:
parent_block_hash = None
else:
parent_block = blocks[num_cached_blocks - 1]
assert parent_block.block_hash is not None
parent_block_hash = parent_block.block_hash.get_hash_value()
self.kv_event_queue.append(
BlockStored(
block_hashes=new_hashes,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
......@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC):
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
caching_hash_fn: Callable,
enable_kv_cache_events: bool,
):
self.kv_cache_config = kv_cache_config
......@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC):
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
......@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC):
manager.allocate_new_blocks(request_id, num_tokens)
for manager in self.single_type_managers)
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
for manager in self.single_type_managers:
manager.cache_blocks(request, block_hashes, num_computed_tokens)
manager.cache_blocks(request, num_computed_tokens)
def free(self, request_id: str) -> None:
"""
......@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, caching_hash_fn: Callable,
enable_kv_cache_events: bool):
use_eagle: bool, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle, False,
caching_hash_fn, enable_kv_cache_events)
enable_kv_cache_events)
self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(self, request_id: str,
......@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
enable_caching, enable_kv_cache_events)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
......@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
enable_caching, enable_kv_cache_events)
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
......@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
enable_caching: bool, caching_hash_fn: Callable,
enable_caching: bool,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
use_eagle, caching_hash_fn,
use_eagle,
enable_kv_cache_events)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
enable_caching, enable_kv_cache_events)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_request_tokens, init_none_hash)
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
......@@ -71,23 +68,13 @@ class KVCacheManager:
kv_cache_config: KVCacheConfig,
max_model_len: int,
enable_caching: bool = True,
caching_hash_algo: str = "builtin",
use_eagle: bool = False,
log_stats: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
self.max_model_len = max_model_len
if len(kv_cache_config.kv_cache_groups) == 0:
# Attention free models don't have kv cache,
# thus don't need prefix caching.
enable_caching = False
self.enable_caching = enable_caching
self.caching_hash_fn = (
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
sha256 if caching_hash_algo == "sha256" else hash)
init_none_hash(self.caching_hash_fn)
self.use_eagle = use_eagle
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
......@@ -107,19 +94,12 @@ class KVCacheManager:
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
enable_caching=self.enable_caching,
caching_hash_fn=self.caching_hash_fn,
enable_kv_cache_events=enable_kv_cache_events,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# `get_computed_blocks` or `allocate_slots`.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
@property
def usage(self) -> float:
"""Get the KV cache usage.
......@@ -161,15 +141,6 @@ class KVCacheManager:
and request.sampling_params.prompt_logprobs is not None)):
return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes:
assert self.block_size is not None
block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just
......@@ -178,7 +149,7 @@ class KVCacheManager:
# could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = (
self.coordinator.find_longest_cache_hit(block_hashes,
self.coordinator.find_longest_cache_hit(request.block_hashes,
max_cache_hit_length))
if self.log_stats:
......@@ -296,11 +267,7 @@ class KVCacheManager:
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
num_tokens_to_cache = min(num_computed_tokens + num_new_tokens,
request.num_tokens)
self.coordinator.cache_blocks(
request,
self.req_to_block_hashes[request.request_id],
num_tokens_to_cache,
)
self.coordinator.cache_blocks(request, num_tokens_to_cache)
return KVCacheBlocks(new_blocks)
......@@ -373,14 +340,6 @@ class KVCacheManager:
return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.
NOTE: Unlike `free`, this method should be called only when the request
is finished, not when it is preempted.
"""
self.req_to_block_hashes.pop(request.request_id, None)
def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool.
......@@ -397,9 +356,7 @@ class KVCacheManager:
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
if self.enable_caching:
block_hashes = self.req_to_block_hashes[request.request_id]
self.coordinator.cache_blocks(request, block_hashes,
num_computed_tokens)
self.coordinator.cache_blocks(request, num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
......
......@@ -547,41 +547,61 @@ def hash_block_tokens(
curr_block_token_ids_tuple, extra_keys)
def hash_request_tokens(hash_function: Any, block_size: int,
request: Request) -> list[BlockHash]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
def get_request_block_hasher(
block_size: int,
caching_hash_fn: Callable[[Any],
int]) -> Callable[[Request], list[BlockHash]]:
"""
Returns a function which computes the list of un-computed block hashes
of a request.
Each request holds a list of its block hashes (request.block_hashes).
When a request is created, it calls the below function to compute
the hashes of all full blocks of the request's initial tokens.
The hashes are then stored in request.block_hashes.
Later, whenever new tokens are appended to the request, it calls
the below function again to compute any new full blocks of tokens.
The returned new hashes are appended to request.block_hashes.
"""
Args:
block_size: The size of each block.
request: The request object.
def request_block_hasher(request: Request) -> list[BlockHash]:
start_token_idx = len(request.block_hashes) * block_size
num_tokens = request.num_tokens
curr_mm_idx = 0
if start_token_idx > 0:
# Set curr_mm_idx = -1 to indicate the last mm input.
# Note that since we reach to this branch only when the block is
# completed with generated tokens, we only need to consider the
# last mm input.
curr_mm_idx = -1
prev_block_hash_value = request.block_hashes[-1].hash_value \
if request.block_hashes else None
new_block_hashes: list[BlockHash] = []
while True:
end_token_idx = start_token_idx + block_size
if end_token_idx > num_tokens:
# We only hash full blocks
break
Returns:
The list of computed hash values.
"""
token_ids = request.all_token_ids
# MM and LoRA requests need extra keys for block-hash computation.
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, curr_mm_idx)
req_need_extra_keys = need_extra_keys(request)
req_extra_keys = None
curr_mm_idx = 0
# Compute the hash of the current block
block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
block_hash = hash_block_tokens(caching_hash_fn,
prev_block_hash_value, block_tokens,
extra_keys)
ret = []
parent_block_hash_value = None
# Only full blocks will be hashed
for start in range(0, len(token_ids) - block_size + 1, block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
new_block_hashes.append(block_hash)
start_token_idx += block_size
prev_block_hash_value = block_hash.hash_value
if req_need_extra_keys:
# MM and LoRA requests need extra keys for block-hash computation.
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
block_token_ids, req_extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
return new_block_hashes
return request_block_hasher
def max_memory_usage_bytes(vllm_config: VllmConfig,
......
......@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface):
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching,
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
......@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface):
def _free_blocks(self, request: Request):
assert request.is_finished()
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int:
......
......@@ -3,7 +3,6 @@
import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
......@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
kv_cache_group_id: int,
caching_hash_fn: Callable,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
......@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
caching_hash_fn: The caching hash function.
"""
self.block_size = kv_cache_spec.block_size
......@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC):
# data for reempted ones.
self.num_cached_block: dict[str, int] = {}
self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
......@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC):
req_blocks.extend(new_blocks)
return new_blocks
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_tokens: int) -> None:
def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
......@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.cache_full_blocks(
request=request,
blocks=self.req_to_blocks[request.request_id],
block_hashes=block_hashes,
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[request.request_id] = num_full_blocks
......
......@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (decorate_logs, make_zmq_socket,
from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
resolve_obj_by_qualname, set_process_title)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
get_request_block_hasher,
init_none_hash,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -140,6 +142,19 @@ class EngineCore:
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
self.request_block_hasher: Optional[Callable[[Request],
list[BlockHash]]] = None
if (self.vllm_config.cache_config.enable_prefix_caching
or self.scheduler.get_kv_connector() is not None):
block_size = vllm_config.cache_config.block_size
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo)
init_none_hash(caching_hash_fn)
self.request_block_hasher = get_request_block_hasher(
block_size, caching_hash_fn)
def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
......@@ -417,7 +432,8 @@ class EngineCore:
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_kwargs, request.mm_hashes)
req = Request.from_engine_core_request(request)
req = Request.from_engine_core_request(request,
self.request_block_hasher)
if req.use_structured_output:
# Note on thread safety: no race condition.
# `grammar_init` is only invoked in input processing thread. For
......
......@@ -3,7 +3,8 @@
import enum
import time
from typing import TYPE_CHECKING, Any, Optional, Union
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
......@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.v1.core.kv_cache_utils import BlockHash
class Request:
......@@ -36,6 +38,8 @@ class Request:
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
block_hasher: Optional[Callable[["Request"],
list["BlockHash"]]] = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
......@@ -108,8 +112,18 @@ class Request:
# indicates that the output is corrupted
self.num_nans_in_logits = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Optional[Callable[
[], list[BlockHash]]] = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
def from_engine_core_request(
cls, request: EngineCoreRequest,
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
) -> "Request":
if request.mm_kwargs is not None:
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
"mm_kwargs was not updated in EngineCore.add_request")
......@@ -131,6 +145,7 @@ class Request:
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
block_hasher=block_hasher,
)
def append_output_token_ids(
......@@ -144,6 +159,9 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
@property
def is_output_corrupted(self) -> bool:
return self.num_nans_in_logits > 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment