"vllm/vscode:/vscode.git/clone" did not exist on "d88f28da05b12bc7d63ebe3dcedf445ecb274343"
Unverified Commit 8f4a1c9a authored by Ruixiang Tan's avatar Ruixiang Tan Committed by GitHub
Browse files

[Misc] Improve code readability of KVCacheManager (#21673)


Signed-off-by: default avatartanruixiang <tanruixiang0104@gmail.com>
Signed-off-by: default avatarRuixiang Tan <819464715@qq.com>
Signed-off-by: default avatarGitHub <noreply@github.com>
parent 36ede459
...@@ -112,9 +112,9 @@ def test_kv_cache_block(): ...@@ -112,9 +112,9 @@ def test_kv_cache_block():
assert block.block_hash is None assert block.block_hash is None
# Test reference count manipulation # Test reference count manipulation
block.incr_ref() block.ref_cnt += 1
assert block.ref_cnt == 1 assert block.ref_cnt == 1
block.decr_ref() block.ref_cnt -= 1
assert block.ref_cnt == 0 assert block.ref_cnt == 0
# Test block hash setting and resetting # Test block hash setting and resetting
......
...@@ -276,7 +276,7 @@ class BlockPool: ...@@ -276,7 +276,7 @@ class BlockPool:
# candidate), so remove it. # candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null: if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block) self.free_block_queue.remove(block)
block.incr_ref() block.ref_cnt += 1
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their """Free a list of blocks. The blocks should be ordered by their
......
...@@ -126,14 +126,17 @@ class KVCacheCoordinator(ABC): ...@@ -126,14 +126,17 @@ class KVCacheCoordinator(ABC):
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> list[int]: num_running_requests: int) -> list[int]:
""" """
Get the number of common prefix blocks for a request. Get the number of common prefix blocks for all requests in the RUNNING
state for each kv cache group.
Args: Args:
request_id: The request ID. request_id: The request ID.
num_running_requests: The number of requests in the RUNNING state. num_running_requests: The total number of requests in the RUNNING
state.
Returns: Returns:
list[int]: The number of common prefix blocks. list[int]: The number of common prefix blocks for all requests in
the RUNNING state for each kv cache group.
""" """
num_blocks_per_group = [ num_blocks_per_group = [
manager.get_num_common_prefix_blocks(request_id, manager.get_num_common_prefix_blocks(request_id,
......
...@@ -170,10 +170,6 @@ class KVCacheManager: ...@@ -170,10 +170,6 @@ class KVCacheManager:
self.block_size, request) self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes self.req_to_block_hashes[request.request_id] = block_hashes
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
# NOTE: When all tokens hit the cache, we must recompute the last token # NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just # This can trigger recomputation of an entire block, rather than just
...@@ -187,6 +183,7 @@ class KVCacheManager: ...@@ -187,6 +183,7 @@ class KVCacheManager:
if self.log_stats: if self.log_stats:
assert self.prefix_cache_stats is not None assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += request.num_tokens self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_new_computed_tokens self.prefix_cache_stats.hits += num_new_computed_tokens
......
...@@ -154,14 +154,6 @@ class KVCacheBlock: ...@@ -154,14 +154,6 @@ class KVCacheBlock:
# Whether the block is a null block that should never be cached. # Whether the block is a null block that should never be cached.
is_null: bool = False is_null: bool = False
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
# avoid function calls.
def incr_ref(self):
self.ref_cnt += 1
def decr_ref(self):
self.ref_cnt -= 1
@property @property
def block_hash(self) -> Optional[BlockHashWithGroupId]: def block_hash(self) -> Optional[BlockHashWithGroupId]:
return self._block_hash return self._block_hash
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import Callable from typing import Callable
...@@ -177,14 +178,17 @@ class SingleTypeKVCacheManager(ABC): ...@@ -177,14 +178,17 @@ class SingleTypeKVCacheManager(ABC):
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int: num_running_requests: int) -> int:
""" """
Get the number of common prefix blocks for a request. Get the number of common prefix blocks for all requests in the RUNNING
state.
Args: Args:
request_id: The request ID. request_id: The request ID.
num_running_requests: The number of requests in the RUNNING state. num_running_requests: The total number of requests in the RUNNING
state.
Returns: Returns:
The number of common prefix blocks. The number of common prefix blocks for all requests in the RUNNING
state.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -264,7 +268,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): ...@@ -264,7 +268,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids))) [] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size max_num_blocks = max_length // kv_cache_spec.block_size
for i, block_hash in zip(range(max_num_blocks), block_hashes): for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not # block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are # in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure. # not computed yet for sure.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment