Unverified commit f8ece6e1, authored by Shawn Du and committed by GitHub
Browse files

[Core][v1] Unify allocating slots in prefill and decode in KV cache manager (#12608)

As mentioned in RFC https://github.com/vllm-project/vllm/issues/12254,
this PR achieves the task: combine allocate_slots and append_slots.

There should be no functional change, except that decode now also
raises an exception when num_tokens is zero (as prefill already does);
the unit test case is changed accordingly.

@comaniac @rickyyx @WoosukKwon @youkaichao @heheda12345 @simon-mo

---------
Signed-off-by: Shawn Du <shawnd200@outlook.com>
parent abfcdcdf
...@@ -164,7 +164,7 @@ def test_decode(): ...@@ -164,7 +164,7 @@ def test_decode():
req0.num_computed_tokens = 55 req0.num_computed_tokens = 55
for _ in range(4): for _ in range(4):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 4) new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
...@@ -175,7 +175,7 @@ def test_decode(): ...@@ -175,7 +175,7 @@ def test_decode():
# the preallocated block. # the preallocated block.
for _ in range(5 + 10): for _ in range(5 + 10):
req0.append_output_token_ids(7) req0.append_output_token_ids(7)
new_blocks = manager.append_slots(req0, 15) new_blocks = manager.allocate_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
...@@ -185,7 +185,7 @@ def test_decode(): ...@@ -185,7 +185,7 @@ def test_decode():
# the preallocated block. # the preallocated block.
for _ in range(6 + 11): for _ in range(6 + 11):
req0.append_output_token_ids(12) req0.append_output_token_ids(12)
new_blocks = manager.append_slots(req0, 17) new_blocks = manager.allocate_slots(req0, 17)
# Plus one preallocated block. # Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2 assert new_blocks is not None and len(new_blocks) == 2
...@@ -395,12 +395,14 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): ...@@ -395,12 +395,14 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
req.num_computed_tokens = block_size req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks assert len(blocks) == 1 + num_preallocated_blocks
# Assume all computed. # Assume all computed, only when num_preallocate_tokens > 0, we need to
manager.append_slots(req, block_size * (len(blocks) - 1)) # consume the previously preallocated blocks.
if num_preallocated_blocks > 0:
manager.allocate_slots(req, block_size * (len(blocks) - 1))
req.num_computed_tokens = block_size * len(blocks) req.num_computed_tokens = block_size * len(blocks)
# Append 1 block. # Append 1 block.
blocks = manager.append_slots(req, block_size) blocks = manager.allocate_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks assert len(blocks) == 1 + num_preallocated_blocks
...@@ -503,7 +505,7 @@ def test_mm_prefix_caching(): ...@@ -503,7 +505,7 @@ def test_mm_prefix_caching():
# Append slots without allocating a new block. # Append slots without allocating a new block.
for _ in range(5): for _ in range(5):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 5) new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
# The just completed block should have hashes with extra keys. # The just completed block should have hashes with extra keys.
...@@ -603,7 +605,7 @@ def test_reset_prefix_cache(): ...@@ -603,7 +605,7 @@ def test_reset_prefix_cache():
unique_token_ids = [3] * 7 unique_token_ids = [3] * 7
all_token_ids = full_block_token_ids + unique_token_ids all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids) req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55, []) blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [0, 1, 2, 3] assert [b.block_id for b in blocks] == [0, 1, 2, 3]
unique_token_ids = [4] * 7 unique_token_ids = [4] * 7
...@@ -639,7 +641,7 @@ def test_uncache_blocks(): ...@@ -639,7 +641,7 @@ def test_uncache_blocks():
) )
req0 = make_request("0", list(range(30))) req0 = make_request("0", list(range(30)))
blocks = manager.allocate_slots(req0, 30, []) blocks = manager.allocate_slots(req0, 30)
assert [b.block_id for b in blocks] == [0, 1] assert [b.block_id for b in blocks] == [0, 1]
assert len(manager.cached_block_hash_to_block) == 1 assert len(manager.cached_block_hash_to_block) == 1
...@@ -648,7 +650,7 @@ def test_uncache_blocks(): ...@@ -648,7 +650,7 @@ def test_uncache_blocks():
# Simulate speculative tokens. # Simulate speculative tokens.
for _ in range(5): for _ in range(5):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
manager.append_slots(req0, 5) manager.allocate_slots(req0, 5)
assert len(manager.cached_block_hash_to_block) == 2 assert len(manager.cached_block_hash_to_block) == 2
# After sampling, assuming only 1 token is accepted. # After sampling, assuming only 1 token is accepted.
......
from collections import defaultdict from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -67,7 +67,8 @@ class KVCacheManager: ...@@ -67,7 +67,8 @@ class KVCacheManager:
# Mapping from request ID to blocks to track the blocks allocated # Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request # for each request, so that we can free the blocks when the request
# is finished. # is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {} self.req_to_blocks: DefaultDict[str,
List[KVCacheBlock]] = defaultdict(list)
@property @property
def usage(self) -> float: def usage(self) -> float:
...@@ -115,157 +116,116 @@ class KVCacheManager: ...@@ -115,157 +116,116 @@ class KVCacheManager:
num_computed_tokens = len(computed_blocks) * self.block_size num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens return computed_blocks, num_computed_tokens
def append_slots(
self,
request: Request,
num_tokens: int,
) -> Optional[List[KVCacheBlock]]:
"""Append slots to the block table of the request.
We first append slots to already allocated blocks. If the allocated
blocks are not enough, we allocate new blocks.
Args:
request: The request to append slots.
num_tokens: The number of tokens to append.
Returns:
A list of new blocks if new blocks are allocated, or None
if new blocks are required but cannot be allocated.
"""
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = num_required_blocks - len(req_blocks)
if num_new_blocks > self.free_block_queue.num_free_blocks:
# Need to allocate new blocks due to insufficient pre-allocated
# slots, but we cannot allocate new blocks due to the limit.
return None
if num_new_blocks <= 0:
# No new block is needed.
new_blocks = []
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(req_blocks),
)
assert num_new_blocks > 0
new_blocks = self._get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if not self.enable_caching:
return new_blocks
num_computed_full_blocks = (request.num_computed_tokens //
self.block_size)
# NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are
# full after appending the actual tokens.
num_full_blocks_after_append = (request.num_computed_tokens +
num_tokens) // self.block_size
assert num_full_blocks_after_append <= len(req_blocks)
new_full_blocks = req_blocks[
num_computed_full_blocks:num_full_blocks_after_append]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=num_computed_full_blocks,
full_blocks=new_full_blocks,
prev_block=req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks >= 1 else None,
)
return new_blocks
def allocate_slots( def allocate_slots(
self, self,
request: Request, request: Request,
num_tokens: int, num_tokens: int,
computed_blocks: List[KVCacheBlock], new_computed_blocks: Optional[List[KVCacheBlock]] = None
) -> Optional[List[KVCacheBlock]]: ) -> Optional[List[KVCacheBlock]]:
"""Allocate slots for a new request. """Add slots for a request with new tokens to append.
Args: Args:
request: The request to allocate slots. request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed. not include the tokens that have already been computed.
computed_blocks: A list of computed blocks. new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
Blocks layout:
-----------------------------------------------------------------------
| < computed > | < new computed > | < new > | < pre-allocated > |
-----------------------------------------------------------------------
| < required > |
--------------------------------------------------
| < full > |
------------------------------------------------
| <new full> |
--------------
The following *_blocks are illustrated in this layout.
Returns: Returns:
A list of new allocated blocks. A list of new allocated blocks.
""" """
if num_tokens == 0: if num_tokens == 0:
raise ValueError( raise ValueError("num_tokens must be greater than 0")
f"num_tokens must be greater than 0, got {num_tokens}")
new_computed_blocks = new_computed_blocks or []
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
# If a computed block of a request is an eviction candidate (in the # If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block # free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request. # when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in computed_blocks num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
if blk.ref_cnt == 0) if blk.ref_cnt == 0)
if (num_new_blocks > self.free_block_queue.num_free_blocks -
num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks): num_evictable_computed_blocks):
# Cannot allocate new blocks. # Cannot allocate new blocks
return None return None
# Touch the computed blocks to make sure they won't be evicted. # Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching: if self.enable_caching:
self._touch(computed_blocks) self._touch(new_computed_blocks)
else: else:
assert not computed_blocks, ( assert not new_computed_blocks, (
"Computed blocks should be empty when " "Computed blocks should be empty when "
"prefix caching is disabled") "prefix caching is disabled")
# Determine the number of new blocks to allocate considering # Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
req_blocks.extend(new_computed_blocks)
# Start to handle new blocks
if num_new_blocks <= 0:
# No new block is needed.
new_blocks = []
else:
# Get new blocks from the free block pool considering
# preallocated blocks. # preallocated blocks.
num_new_blocks = min( num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks, num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks, self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request. # Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape # This is especially because the block table has the shape
# [..., max_num_blocks_per_req]. # [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if # TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len. # num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(computed_blocks), self.max_num_blocks_per_req - len(req_blocks),
) )
assert num_new_blocks > 0 assert num_new_blocks > 0
# Concatenate the computed block IDs and the new block IDs. # Concatenate the computed block IDs and the new block IDs.
new_blocks = self._get_new_blocks(num_new_blocks) new_blocks = self._get_new_blocks(num_new_blocks)
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks req_blocks.extend(new_blocks)
if not self.enable_caching: if not self.enable_caching:
return new_blocks return new_blocks
num_computed_tokens = len(computed_blocks) * self.block_size # NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are
# full after appending the actual tokens.
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
num_computed_full_blocks = num_computed_tokens // self.block_size
new_full_blocks = self.req_to_blocks[ new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks: if new_full_blocks:
self._cache_full_blocks( self._cache_full_blocks(
request=request, request=request,
blk_start_idx=len(computed_blocks), blk_start_idx=num_computed_full_blocks,
# The new full blocks are the full blocks that are not computed. # The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks, full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None, prev_block=(req_blocks[num_computed_full_blocks - 1]
) if num_computed_full_blocks > 0 else None))
return new_blocks return new_blocks
......
...@@ -138,7 +138,7 @@ class Scheduler: ...@@ -138,7 +138,7 @@ class Scheduler:
assert num_new_tokens > 0 assert num_new_tokens > 0
while True: while True:
new_blocks = self.kv_cache_manager.append_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens) request, num_new_tokens)
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment