Commit eef99f73 authored by laibao
Browse files

feat: kvpress新增 KV cache 申请/截断支持

parent 8d3d07fc
...@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC): ...@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers: for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens) manager.remove_skipped_blocks(request_id, num_computed_tokens)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
    """Trim the blocks held for a request down to `num_tokens` slots.

    The call is forwarded to every manager in `self.single_type_managers`.
    Every manager is always invoked, even after an earlier one has already
    reported freeing blocks.

    Returns True if any manager reported that blocks were freed.
    """
    any_freed = False
    for type_manager in self.single_type_managers:
        # Evaluate the call before combining so `or` cannot skip it.
        freed_here = type_manager.truncate_to_num_tokens(
            request_id, num_tokens)
        any_freed = freed_here or any_freed
    return any_freed
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
""" """
Get the blocks for the request. Get the blocks for the request.
......
...@@ -7,6 +7,8 @@ from typing import Optional ...@@ -7,6 +7,8 @@ from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger from vllm.logger import init_logger
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import sha256 from vllm.utils import sha256
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
...@@ -251,9 +253,17 @@ class KVCacheManager: ...@@ -251,9 +253,17 @@ class KVCacheManager:
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
num_new_computed_tokens) num_new_computed_tokens)
num_tokens_need_slot = min( if envs.VLLM_ENABLE_KV_COMPRESSION and not current_platform.is_tpu():
num_computed_tokens + num_new_tokens + num_lookahead_tokens, # KV compression decouples logical positions from KV cache
self.max_model_len) # positions. Allocate based on the KV cache length (plus the tokens
# scheduled for this step, which are temporarily written to cache).
num_tokens_need_slot = min(
request.num_kv_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
else:
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
request_id=request.request_id, request_id=request.request_id,
...@@ -385,6 +395,14 @@ class KVCacheManager: ...@@ -385,6 +395,14 @@ class KVCacheManager:
return KVCacheBlocks( return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids() self.coordinator.get_blocks(request_id)).get_block_ids()
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
    """Ask the coordinator to trim a request's blocks to `num_tokens` slots.

    Best-effort: blocks that are no longer needed may be returned to the
    pool. The coordinator's result is passed through unchanged.

    Returns True if any blocks were freed.
    """
    freed = self.coordinator.truncate_to_num_tokens(request_id, num_tokens)
    return freed
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled.""" """Cache the blocks for the request, if enabled."""
if self.enable_caching: if self.enable_caching:
......
...@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC): ...@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.free_blocks(ordered_blocks) self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None) self.num_cached_block.pop(request_id, None)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
    """Optional hook to trim a request's blocks to `num_tokens` slots.

    The base implementation does nothing and reports that no blocks were
    freed. Subclasses may override it to release blocks that are no longer
    needed (e.g., after KV compaction); callers should treat it as a
    best-effort optimization.

    Returns False, since the default never frees anything.
    """
    return False
@abstractmethod @abstractmethod
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int: num_running_requests: int) -> int:
...@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager): ...@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager):
# No need to remove blocks for full attention. # No need to remove blocks for full attention.
pass pass
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
    """Free the trailing blocks of a request beyond `num_tokens` slots.

    A negative `num_tokens` is treated as 0. The request keeps
    `cdiv(num_tokens, block_size)` leading blocks; any blocks after that
    are removed from the request's list and handed to the block pool in
    reverse order. The request's cached-block count, if tracked, is clamped
    to the number of blocks that remain.

    Returns True if blocks were freed, and False when the request holds no
    blocks or already fits within the required number of blocks.
    """
    slots = int(num_tokens)
    if slots < 0:
        slots = 0

    held = self.req_to_blocks.get(request_id)
    if not held:
        return False

    keep = cdiv(slots, self.block_size)
    if keep >= len(held):
        return False

    # Detach the tail from the request before returning it to the pool.
    tail = held[keep:]
    del held[keep:]
    # NOTE(review): the tail is freed last-block-first; presumably this
    # matters for the pool's eviction order -- confirm against BlockPool.
    self.block_pool.free_blocks(reversed(tail))

    # The cached-block count must not exceed the blocks still held.
    if request_id in self.num_cached_block:
        cached = self.num_cached_block[request_id]
        self.num_cached_block[request_id] = min(cached, len(held))
    return True
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int: num_running_requests: int) -> int:
blocks = self.req_to_blocks[request_id] blocks = self.req_to_blocks[request_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment