Unverified commit aabcd2ca, authored by Chen Zhang, committed by GitHub
Browse files

[v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (#17479)


Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 0d115460
...@@ -542,7 +542,7 @@ def test_allocate_with_lookahead(): ...@@ -542,7 +542,7 @@ def test_allocate_with_lookahead():
num_tokens=3, num_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens num_lookahead_tokens=2, # Total required: 3+2=5 tokens
) )
assert len(blocks) == 2 # ceil(5/4)=2 blocks assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks # Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config, kv_cache_manager = KVCacheManager(kv_cache_config=config,
...@@ -553,7 +553,7 @@ def test_allocate_with_lookahead(): ...@@ -553,7 +553,7 @@ def test_allocate_with_lookahead():
num_tokens=3, num_tokens=3,
num_lookahead_tokens=2, num_lookahead_tokens=2,
) )
assert len(blocks) == 2 assert len(blocks.blocks) == 2
# Test case 3: With precomputed blocks # Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2 # required_blocks = ceil((3 + 4) / 4) = 2
...@@ -564,4 +564,4 @@ def test_allocate_with_lookahead(): ...@@ -564,4 +564,4 @@ def test_allocate_with_lookahead():
num_tokens=3, num_tokens=3,
num_lookahead_tokens=4, num_lookahead_tokens=4,
) )
assert len(blocks) == 2 assert len(blocks.blocks) == 2
This diff is collapsed.
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
...@@ -18,6 +19,24 @@ from vllm.v1.request import Request, RequestStatus ...@@ -18,6 +19,24 @@ from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
    """Thin wrapper around a list of KV cache blocks.

    Serves as the interface object passed between the Scheduler and the
    KVCacheManager, so callers do not handle the raw block list directly.
    """

    # The wrapped blocks, kept in the order they were supplied.
    blocks: list[KVCacheBlock]

    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
        """Concatenate two KVCacheBlocks instances.

        Returns a new instance holding this instance's blocks followed by
        ``other``'s blocks. Neither operand is modified.
        """
        # Follow the binary-operator protocol: decline unsupported operands
        # so Python raises a TypeError naming both operand types, rather
        # than failing with an AttributeError on `other.blocks`.
        if not isinstance(other, KVCacheBlocks):
            return NotImplemented
        return KVCacheBlocks(self.blocks + other.blocks)

    @classmethod
    def create_empty(cls) -> "KVCacheBlocks":
        """Create a new KVCacheBlocks instance with no blocks."""
        # A fresh list per call, so empty instances never share storage.
        return cls([])

    def get_block_ids(self) -> list[int]:
        """Return the ``block_id`` of every wrapped block, in order."""
        return [block.block_id for block in self.blocks]
class KVCacheManager: class KVCacheManager:
def __init__( def __init__(
...@@ -94,8 +113,8 @@ class KVCacheManager: ...@@ -94,8 +113,8 @@ class KVCacheManager:
self.prefix_cache_stats = PrefixCacheStats() self.prefix_cache_stats = PrefixCacheStats()
return stats return stats
def get_computed_blocks( def get_computed_blocks(self,
self, request: Request) -> tuple[list[KVCacheBlock], int]: request: Request) -> tuple[KVCacheBlocks, int]:
"""Get the computed (cached) blocks for the request. """Get the computed (cached) blocks for the request.
Note that the computed blocks must be full. Note that the computed blocks must be full.
...@@ -109,7 +128,7 @@ class KVCacheManager: ...@@ -109,7 +128,7 @@ class KVCacheManager:
""" """
if not self.enable_caching: if not self.enable_caching:
# Prefix caching is disabled. # Prefix caching is disabled.
return [], 0 return KVCacheBlocks.create_empty(), 0
# The block hashes for the request may already be computed # The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before. # if the scheduler has tried to schedule the request before.
...@@ -124,7 +143,7 @@ class KVCacheManager: ...@@ -124,7 +143,7 @@ class KVCacheManager:
self.prefix_cache_stats.requests += 1 self.prefix_cache_stats.requests += 1
# When the request requires prompt logprobs, we skip prefix caching. # When the request requires prompt logprobs, we skip prefix caching.
if request.sampling_params.prompt_logprobs is not None: if request.sampling_params.prompt_logprobs is not None:
return [], 0 return KVCacheBlocks.create_empty(), 0
if len(block_hashes) * self.block_size == request.num_tokens: if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all # When prompt length is divisible by the block size and all
...@@ -157,15 +176,15 @@ class KVCacheManager: ...@@ -157,15 +176,15 @@ class KVCacheManager:
# sharing, `num_computed_tokens` is always a multiple of # sharing, `num_computed_tokens` is always a multiple of
# `block_size`. # `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens return KVCacheBlocks(computed_blocks), num_computed_tokens
def allocate_slots( def allocate_slots(
self, self,
request: Request, request: Request,
num_tokens: int, num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None, new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0, num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]: ) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append. """Add slots for a request with new tokens to append.
Args: Args:
...@@ -173,7 +192,7 @@ class KVCacheManager: ...@@ -173,7 +192,7 @@ class KVCacheManager:
num_tokens: The number of tokens to allocate, including external num_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks). already been computed locally (i.e. new_computed_blocks).
new_computed_blocks: A list of new computed blocks just hitting the new_computed_blocks: The new computed blocks just hitting the
prefix caching. prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate. num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such This is used by spec decode proposers with kv-cache such
...@@ -199,7 +218,10 @@ class KVCacheManager: ...@@ -199,7 +218,10 @@ class KVCacheManager:
if num_tokens == 0: if num_tokens == 0:
raise ValueError("num_tokens must be greater than 0") raise ValueError("num_tokens must be greater than 0")
new_computed_blocks = new_computed_blocks or [] if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = []
req_blocks = self.req_to_blocks[request.request_id] req_blocks = self.req_to_blocks[request.request_id]
...@@ -216,17 +238,18 @@ class KVCacheManager: ...@@ -216,17 +238,18 @@ class KVCacheManager:
# The number of computed tokens is the number of computed tokens plus # The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size) len(new_computed_block_list) * self.block_size)
num_required_blocks = cdiv( num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens, num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size) self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) - num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks)) len(new_computed_block_list))
# If a computed block of a request is an eviction candidate (in the # If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block # free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request. # when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks num_evictable_computed_blocks = sum(1
for blk in new_computed_block_list
if blk.ref_cnt == 0) if blk.ref_cnt == 0)
if (num_new_blocks > self.block_pool.get_num_free_blocks() - if (num_new_blocks > self.block_pool.get_num_free_blocks() -
num_evictable_computed_blocks): num_evictable_computed_blocks):
...@@ -235,15 +258,15 @@ class KVCacheManager: ...@@ -235,15 +258,15 @@ class KVCacheManager:
# Touch the computed blocks to make sure they won't be evicted. # Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching: if self.enable_caching:
self.block_pool.touch(new_computed_blocks) self.block_pool.touch(new_computed_block_list)
else: else:
assert not new_computed_blocks, ( assert not new_computed_block_list, (
"Computed blocks should be empty when " "Computed blocks should be empty when "
"prefix caching is disabled") "prefix caching is disabled")
# Append the new computed blocks to the request blocks until now to # Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated. # avoid the case where the new blocks cannot be allocated.
req_blocks.extend(new_computed_blocks) req_blocks.extend(new_computed_block_list)
# Start to handle new blocks # Start to handle new blocks
...@@ -267,12 +290,12 @@ class KVCacheManager: ...@@ -267,12 +290,12 @@ class KVCacheManager:
req_blocks.extend(new_blocks) req_blocks.extend(new_blocks)
if not self.enable_caching: if not self.enable_caching:
return new_blocks return KVCacheBlocks(new_blocks)
# Use `new_computed_blocks` for a new request, and `num_cached_block` # Use `new_computed_block_list` for a new request, and
# for a running request. # `num_cached_block` for a running request.
num_cached_blocks = self.num_cached_block.get(request.request_id, num_cached_blocks = self.num_cached_block.get(
len(new_computed_blocks)) request.request_id, len(new_computed_block_list))
# Speculated tokens might be rejected in the future, so we do # Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with # not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens. # generated (accepted) tokens.
...@@ -291,7 +314,7 @@ class KVCacheManager: ...@@ -291,7 +314,7 @@ class KVCacheManager:
self.num_cached_block[ self.num_cached_block[
request.request_id] = num_full_blocks_after_append request.request_id] = num_full_blocks_after_append
return new_blocks return KVCacheBlocks(new_blocks)
def free(self, request: Request) -> None: def free(self, request: Request) -> None:
"""Free the blocks allocated for the request. """Free the blocks allocated for the request.
......
...@@ -261,9 +261,8 @@ class Scheduler(SchedulerInterface): ...@@ -261,9 +261,8 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = [ req_to_new_block_ids[request.request_id] = (
b.block_id for b in new_blocks new_blocks.get_block_ids())
]
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
...@@ -407,9 +406,8 @@ class Scheduler(SchedulerInterface): ...@@ -407,9 +406,8 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [ req_to_new_block_ids[request.request_id] = (
b.block_id for b in computed_blocks + new_blocks computed_blocks + new_blocks).get_block_ids()
]
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment