Unverified Commit aabcd2ca authored by Chen Zhang's avatar Chen Zhang Committed by GitHub
Browse files

[v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (#17479)


Signed-off-by: default avatarChen Zhang <zhangch99@outlook.com>
parent 0d115460
......@@ -542,7 +542,7 @@ def test_allocate_with_lookahead():
num_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
)
assert len(blocks) == 2 # ceil(5/4)=2 blocks
assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,
......@@ -553,7 +553,7 @@ def test_allocate_with_lookahead():
num_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks) == 2
assert len(blocks.blocks) == 2
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
......@@ -564,4 +564,4 @@ def test_allocate_with_lookahead():
num_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks) == 2
assert len(blocks.blocks) == 2
This diff is collapsed.
......@@ -2,6 +2,7 @@
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
......@@ -18,6 +19,24 @@ from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
blocks: list[KVCacheBlock]
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances."""
return KVCacheBlocks(self.blocks + other.blocks)
@classmethod
def create_empty(cls) -> "KVCacheBlocks":
"""Creates a new KVCacheBlocks instance with no blocks."""
return cls([])
def get_block_ids(self) -> list[int]:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
return [block.block_id for block in self.blocks]
class KVCacheManager:
def __init__(
......@@ -94,8 +113,8 @@ class KVCacheManager:
self.prefix_cache_stats = PrefixCacheStats()
return stats
def get_computed_blocks(
self, request: Request) -> tuple[list[KVCacheBlock], int]:
def get_computed_blocks(self,
request: Request) -> tuple[KVCacheBlocks, int]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
......@@ -109,7 +128,7 @@ class KVCacheManager:
"""
if not self.enable_caching:
# Prefix caching is disabled.
return [], 0
return KVCacheBlocks.create_empty(), 0
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
......@@ -124,7 +143,7 @@ class KVCacheManager:
self.prefix_cache_stats.requests += 1
# When the request requires prompt logprobs, we skip prefix caching.
if request.sampling_params.prompt_logprobs is not None:
return [], 0
return KVCacheBlocks.create_empty(), 0
if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all
......@@ -157,15 +176,15 @@ class KVCacheManager:
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
return KVCacheBlocks(computed_blocks), num_computed_tokens
def allocate_slots(
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]:
) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append.
Args:
......@@ -173,7 +192,7 @@ class KVCacheManager:
num_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
new_computed_blocks: A list of new computed blocks just hitting the
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
......@@ -199,7 +218,10 @@ class KVCacheManager:
if num_tokens == 0:
raise ValueError("num_tokens must be greater than 0")
new_computed_blocks = new_computed_blocks or []
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = []
req_blocks = self.req_to_blocks[request.request_id]
......@@ -216,17 +238,18 @@ class KVCacheManager:
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
len(new_computed_block_list) * self.block_size)
num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
len(new_computed_block_list))
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
num_evictable_computed_blocks = sum(1
for blk in new_computed_block_list
if blk.ref_cnt == 0)
if (num_new_blocks > self.block_pool.get_num_free_blocks() -
num_evictable_computed_blocks):
......@@ -235,15 +258,15 @@ class KVCacheManager:
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self.block_pool.touch(new_computed_blocks)
self.block_pool.touch(new_computed_block_list)
else:
assert not new_computed_blocks, (
assert not new_computed_block_list, (
"Computed blocks should be empty when "
"prefix caching is disabled")
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
req_blocks.extend(new_computed_blocks)
req_blocks.extend(new_computed_block_list)
# Start to handle new blocks
......@@ -267,12 +290,12 @@ class KVCacheManager:
req_blocks.extend(new_blocks)
if not self.enable_caching:
return new_blocks
return KVCacheBlocks(new_blocks)
# Use `new_computed_blocks` for a new request, and `num_cached_block`
# for a running request.
num_cached_blocks = self.num_cached_block.get(request.request_id,
len(new_computed_blocks))
# Use `new_computed_block_list` for a new request, and
# `num_cached_block` for a running request.
num_cached_blocks = self.num_cached_block.get(
request.request_id, len(new_computed_block_list))
# Speculated tokens might be rejected in the future, so we does
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
......@@ -291,7 +314,7 @@ class KVCacheManager:
self.num_cached_block[
request.request_id] = num_full_blocks_after_append
return new_blocks
return KVCacheBlocks(new_blocks)
def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
......
......@@ -261,9 +261,8 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
......@@ -407,9 +406,8 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks
]
req_to_new_block_ids[request.request_id] = (
computed_blocks + new_blocks).get_block_ids()
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment