Unverified Commit f8ece6e1 authored by Shawn Du's avatar Shawn Du Committed by GitHub
Browse files

[Core][v1] Unify allocating slots in prefill and decode in KV cache manager (#12608)

As mentioned in RFC https://github.com/vllm-project/vllm/issues/12254

,
this PR achieves the task: combine allocate_slots and append_slots.

There should be no functionality change, except that in decode, also
raise exception when num_tokens is zero (like prefill), and change the
unit test case accordingly.

@comaniac @rickyyx @WoosukKwon @youkaichao @heheda12345 @simon-mo

---------
Signed-off-by: default avatarShawn Du <shawnd200@outlook.com>
parent abfcdcdf
......@@ -164,7 +164,7 @@ def test_decode():
req0.num_computed_tokens = 55
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 4)
new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
......@@ -175,7 +175,7 @@ def test_decode():
# the preallocated block.
for _ in range(5 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.append_slots(req0, 15)
new_blocks = manager.allocate_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
......@@ -185,7 +185,7 @@ def test_decode():
# the preallocated block.
for _ in range(6 + 11):
req0.append_output_token_ids(12)
new_blocks = manager.append_slots(req0, 17)
new_blocks = manager.allocate_slots(req0, 17)
# Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2
......@@ -395,12 +395,14 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks
# Assume all computed.
manager.append_slots(req, block_size * (len(blocks) - 1))
# Assume all computed, only when num_preallocate_tokens > 0, we need to
# consume the previously preallocated blocks.
if num_preallocated_blocks > 0:
manager.allocate_slots(req, block_size * (len(blocks) - 1))
req.num_computed_tokens = block_size * len(blocks)
# Append 1 block.
blocks = manager.append_slots(req, block_size)
blocks = manager.allocate_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks
......@@ -503,7 +505,7 @@ def test_mm_prefix_caching():
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 5)
new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0
# The just completed block should have hashes with extra keys.
......@@ -603,7 +605,7 @@ def test_reset_prefix_cache():
unique_token_ids = [3] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55, [])
blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [0, 1, 2, 3]
unique_token_ids = [4] * 7
......@@ -639,7 +641,7 @@ def test_uncache_blocks():
)
req0 = make_request("0", list(range(30)))
blocks = manager.allocate_slots(req0, 30, [])
blocks = manager.allocate_slots(req0, 30)
assert [b.block_id for b in blocks] == [0, 1]
assert len(manager.cached_block_hash_to_block) == 1
......@@ -648,7 +650,7 @@ def test_uncache_blocks():
# Simulate speculative tokens.
for _ in range(5):
req0.append_output_token_ids(8)
manager.append_slots(req0, 5)
manager.allocate_slots(req0, 5)
assert len(manager.cached_block_hash_to_block) == 2
# After sampling, assuming only 1 token is accepted.
......
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple
from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
from vllm.logger import init_logger
from vllm.utils import cdiv
......@@ -67,7 +67,8 @@ class KVCacheManager:
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
self.req_to_blocks: DefaultDict[str,
List[KVCacheBlock]] = defaultdict(list)
@property
def usage(self) -> float:
......@@ -115,157 +116,116 @@ class KVCacheManager:
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
def append_slots(
self,
request: Request,
num_tokens: int,
) -> Optional[List[KVCacheBlock]]:
"""Append slots to the block table of the request.
We first append slots to already allocated blocks. If the allocated
blocks are not enough, we allocate new blocks.
Args:
request: The request to append slots.
num_tokens: The number of tokens to append.
Returns:
A list of new blocks if new blocks are allocated, or None
if new blocks are required but cannot be allocated.
"""
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = num_required_blocks - len(req_blocks)
if num_new_blocks > self.free_block_queue.num_free_blocks:
# Need to allocate new blocks due to insufficient pre-allocated
# slots, but we cannot allocate new blocks due to the limit.
return None
if num_new_blocks <= 0:
# No new block is needed.
new_blocks = []
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(req_blocks),
)
assert num_new_blocks > 0
new_blocks = self._get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if not self.enable_caching:
return new_blocks
num_computed_full_blocks = (request.num_computed_tokens //
self.block_size)
# NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are
# full after appending the actual tokens.
num_full_blocks_after_append = (request.num_computed_tokens +
num_tokens) // self.block_size
assert num_full_blocks_after_append <= len(req_blocks)
new_full_blocks = req_blocks[
num_computed_full_blocks:num_full_blocks_after_append]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=num_computed_full_blocks,
full_blocks=new_full_blocks,
prev_block=req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks >= 1 else None,
)
return new_blocks
def allocate_slots(
self,
request: Request,
num_tokens: int,
computed_blocks: List[KVCacheBlock],
new_computed_blocks: Optional[List[KVCacheBlock]] = None
) -> Optional[List[KVCacheBlock]]:
"""Allocate slots for a new request.
"""Add slots for a request with new tokens to append.
Args:
request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks: A list of computed blocks.
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
Blocks layout:
-----------------------------------------------------------------------
| < computed > | < new computed > | < new > | < pre-allocated > |
-----------------------------------------------------------------------
| < required > |
--------------------------------------------------
| < full > |
------------------------------------------------
| <new full> |
--------------
The following *_blocks are illustrated in this layout.
Returns:
A list of new allocated blocks.
"""
if num_tokens == 0:
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")
raise ValueError("num_tokens must be greater than 0")
new_computed_blocks = new_computed_blocks or []
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
if blk.ref_cnt == 0)
num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
if (num_new_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
# Cannot allocate new blocks.
# Cannot allocate new blocks
return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self._touch(computed_blocks)
self._touch(new_computed_blocks)
else:
assert not computed_blocks, (
assert not new_computed_blocks, (
"Computed blocks should be empty when "
"prefix caching is disabled")
# Determine the number of new blocks to allocate considering
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
req_blocks.extend(new_computed_blocks)
# Start to handle new blocks
if num_new_blocks <= 0:
# No new block is needed.
new_blocks = []
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks,
num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(computed_blocks),
self.max_num_blocks_per_req - len(req_blocks),
)
assert num_new_blocks > 0
# Concatenate the computed block IDs and the new block IDs.
new_blocks = self._get_new_blocks(num_new_blocks)
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
req_blocks.extend(new_blocks)
if not self.enable_caching:
return new_blocks
num_computed_tokens = len(computed_blocks) * self.block_size
# NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are
# full after appending the actual tokens.
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
num_computed_full_blocks = num_computed_tokens // self.block_size
new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
blk_start_idx=num_computed_full_blocks,
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
)
prev_block=(req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks > 0 else None))
return new_blocks
......
......@@ -138,7 +138,7 @@ class Scheduler:
assert num_new_tokens > 0
while True:
new_blocks = self.kv_cache_manager.append_slots(
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens)
if new_blocks is None:
# The request cannot be scheduled.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment