Unverified commit 52bf0665, authored by Yifan Qiao, committed by GitHub
Browse files

[Core][Hybrid allocator + connector] Support hybrid allocator + kv cache connector (#30166)


Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: KuntaiDu <kuntai@uchicago.edu>
parent 5326c898
...@@ -21,13 +21,23 @@ from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowS ...@@ -21,13 +21,23 @@ from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowS
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
def get_sliding_window_manager(sliding_window_spec, block_pool): def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
return SlidingWindowManager(sliding_window_spec, block_pool, kv_cache_group_id=0) return SlidingWindowManager(
sliding_window_spec,
block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
)
def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool): def get_chunked_local_attention_manager(
chunked_local_attention_spec, block_pool, enable_caching=True
):
return ChunkedLocalAttentionManager( return ChunkedLocalAttentionManager(
chunked_local_attention_spec, block_pool, kv_cache_group_id=0 chunked_local_attention_spec,
block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
) )
...@@ -332,11 +342,53 @@ def test_get_num_blocks_to_allocate(): ...@@ -332,11 +342,53 @@ def test_get_num_blocks_to_allocate():
] ]
assert ( assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20 manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
== 20
) )
assert ( assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
== 15
)
def test_evictable_cached_blocks_not_double_allocated():
block_size = 2
sliding_window_length = 2 * block_size
sliding_window_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window_length,
)
block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
request_id = "req"
evictable_block = block_pool.blocks[1] # ref_cnt == 0, eviction candidate
num_blocks_to_allocate = manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=2 * block_size,
new_computed_blocks=[evictable_block],
total_computed_tokens=block_size,
)
# Free capacity check should count evictable cached blocks, but allocation
# should only allocate the truly new block.
assert num_blocks_to_allocate == 2
manager.allocate_new_computed_blocks(
request_id,
[evictable_block],
num_local_computed_tokens=block_size,
num_external_computed_tokens=0,
) )
new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4)
assert len(new_blocks) == 1
assert len(manager.req_to_blocks[request_id]) == 2
def test_chunked_local_attention_get_num_blocks_to_allocate(): def test_chunked_local_attention_get_num_blocks_to_allocate():
...@@ -359,8 +411,10 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): ...@@ -359,8 +411,10 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
] ]
assert ( assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20 manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
== 20
) )
assert ( assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
== 15
) )
...@@ -254,6 +254,10 @@ class BlockPool: ...@@ -254,6 +254,10 @@ class BlockPool:
[] if self.enable_kv_cache_events else None [] if self.enable_kv_cache_events else None
) )
for i, blk in enumerate(new_full_blocks): for i, blk in enumerate(new_full_blocks):
# Some blocks may be null blocks when enabling sparse attention like
# sliding window attention. We skip null blocks here.
if blk.is_null:
continue
assert blk.block_hash is None assert blk.block_hash is None
block_hash = new_block_hashes[i] block_hash = new_block_hashes[i]
...@@ -361,7 +365,7 @@ class BlockPool: ...@@ -361,7 +365,7 @@ class BlockPool:
) )
return True return True
def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None: def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
"""Touch a block increases its reference count by 1, and may remove """Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by the block from the free queue. This is used when a block is hit by
another request with the same prefix. another request with the same prefix.
...@@ -369,15 +373,14 @@ class BlockPool: ...@@ -369,15 +373,14 @@ class BlockPool:
Args: Args:
blocks: A list of blocks to touch. blocks: A list of blocks to touch.
""" """
for blocks_per_group in blocks: for block in blocks:
for block in blocks_per_group: # ref_cnt=0 means this block is in the free list (i.e. eviction
# ref_cnt=0 means this block is in the free list (i.e. eviction # candidate), so remove it.
# candidate), so remove it. if block.ref_cnt == 0 and not block.is_null:
if block.ref_cnt == 0 and not block.is_null: self.free_block_queue.remove(block)
self.free_block_queue.remove(block) block.ref_cnt += 1
block.ref_cnt += 1 if self.metrics_collector:
if self.metrics_collector: self.metrics_collector.on_block_accessed(block)
self.metrics_collector.on_block_accessed(block)
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their """Free a list of blocks. The blocks should be ordered by their
......
...@@ -60,6 +60,7 @@ class KVCacheCoordinator(ABC): ...@@ -60,6 +60,7 @@ class KVCacheCoordinator(ABC):
get_manager_for_kv_cache_spec( get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec, kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool, block_pool=self.block_pool,
enable_caching=enable_caching,
kv_cache_group_id=i, kv_cache_group_id=i,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
...@@ -73,6 +74,7 @@ class KVCacheCoordinator(ABC): ...@@ -73,6 +74,7 @@ class KVCacheCoordinator(ABC):
num_tokens: int, num_tokens: int,
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...], new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_encoder_tokens: int, num_encoder_tokens: int,
total_computed_tokens: int,
) -> int: ) -> int:
""" """
Get the number of blocks needed to be allocated for the request. Get the number of blocks needed to be allocated for the request.
...@@ -85,9 +87,10 @@ class KVCacheCoordinator(ABC): ...@@ -85,9 +87,10 @@ class KVCacheCoordinator(ABC):
prefix caching. prefix caching.
num_encoder_tokens: The number of encoder tokens for allocating num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention. blocks for cross-attention.
total_computed_tokens: Include both local and external tokens.
Returns: Returns:
The number of blocks. The number of blocks to allocate.
""" """
num_blocks_to_allocate = 0 num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers): for i, manager in enumerate(self.single_type_managers):
...@@ -95,30 +98,48 @@ class KVCacheCoordinator(ABC): ...@@ -95,30 +98,48 @@ class KVCacheCoordinator(ABC):
# For cross-attention, we issue a single static allocation # For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens. # of blocks based on the number of encoder input tokens.
num_blocks_to_allocate += manager.get_num_blocks_to_allocate( num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, [] request_id, num_encoder_tokens, [], 0
) )
else: else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate( num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i] request_id,
num_tokens,
new_computed_blocks[i],
total_computed_tokens,
) )
return num_blocks_to_allocate return num_blocks_to_allocate
def save_new_computed_blocks( def allocate_new_computed_blocks(
self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...] self,
request_id: str,
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_local_computed_tokens: int,
num_external_computed_tokens: int,
) -> None: ) -> None:
""" """
Add the new computed blocks to the request. Add the new computed blocks to the request. Optionally allocate new
blocks for external computed tokens (if any).
Args: Args:
request_id: The request ID. request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the new_computed_blocks: The new computed blocks just hitting the
prefix cache. prefix cache.
num_local_computed_tokens: The number of local computed tokens.
num_external_computed_tokens: The number of external computed tokens.
""" """
for i, manager in enumerate(self.single_type_managers): for i, manager in enumerate(self.single_type_managers):
manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) manager.allocate_new_computed_blocks(
request_id,
new_computed_blocks[i],
num_local_computed_tokens,
num_external_computed_tokens,
)
def allocate_new_blocks( def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0 self,
request_id: str,
num_tokens: int,
num_encoder_tokens: int = 0,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
""" """
Allocate new blocks for the request to give it at least `num_tokens` Allocate new blocks for the request to give it at least `num_tokens`
...@@ -184,17 +205,20 @@ class KVCacheCoordinator(ABC): ...@@ -184,17 +205,20 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers for manager in self.single_type_managers
] ]
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: def remove_skipped_blocks(
self, request_id: str, total_computed_tokens: int
) -> None:
""" """
Remove the blocks that are no longer needed from `blocks` and replace Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block. the removed blocks with null_block.
Args: Args:
request_id: The request ID. request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed. total_computed_tokens: The total number of computed tokens, including
local computed tokens and external computed tokens.
""" """
for manager in self.single_type_managers: for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens) manager.remove_skipped_blocks(request_id, total_computed_tokens)
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
""" """
......
...@@ -210,6 +210,7 @@ class KVCacheManager: ...@@ -210,6 +210,7 @@ class KVCacheManager:
num_new_computed_tokens: int = 0, num_new_computed_tokens: int = 0,
new_computed_blocks: KVCacheBlocks | None = None, new_computed_blocks: KVCacheBlocks | None = None,
num_lookahead_tokens: int = 0, num_lookahead_tokens: int = 0,
num_external_computed_tokens: int = 0,
delay_cache_blocks: bool = False, delay_cache_blocks: bool = False,
num_encoder_tokens: int = 0, num_encoder_tokens: int = 0,
) -> KVCacheBlocks | None: ) -> KVCacheBlocks | None:
...@@ -217,16 +218,16 @@ class KVCacheManager: ...@@ -217,16 +218,16 @@ class KVCacheManager:
Args: Args:
request: The request to allocate slots. request: The request to allocate slots.
num_new_tokens: The number of tokens to allocate, including external num_new_tokens: The number of new tokens to be allocated and computed.
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
num_new_computed_tokens: The number of new computed tokens just num_new_computed_tokens: The number of new computed tokens just
hitting the prefix caching, excluding external tokens. hitting the prefix caching, excluding external tokens.
new_computed_blocks: The cached blocks for the above new computed new_computed_blocks: The cached blocks for the above new computed
tokens. tokens, grouped as a tuple by kv cache groups.
num_lookahead_tokens: The number of speculative tokens to allocate. num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such This is used by spec decode proposers with kv-cache such
as eagle. as eagle.
num_external_computed_tokens: The number of tokens whose KV caches
are not cached by vLLM but are cached by the connector.
delay_cache_blocks: Whether to skip caching the blocks. This is delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer used by P/D when allocating blocks used in a KV transfer
which will complete in a future step. which will complete in a future step.
...@@ -236,29 +237,81 @@ class KVCacheManager: ...@@ -236,29 +237,81 @@ class KVCacheManager:
Blocks layout: Blocks layout:
``` ```
----------------------------------------------------------------------- ----------------------------------------------------------------------
| < computed > | < new computed > | < new > | < pre-allocated > | | < comp > | < new_comp > | < ext_comp > | < new > | < lookahead > |
----------------------------------------------------------------------- ----------------------------------------------------------------------
| < required > | | < to be computed > |
-------------------------------------------------- ----------------------------------------------------------------------
| < full > | | < to be allocated > |
------------------------------------------------ ----------------------------------------------------------------------
| <new full> | | < to be cached (roughly, |
-------------- | details below)> |
----------------------------------------------------------------------
| Prefix-cached tokens from either vLLM |
| or connector. Can be safely removed if |
| they are outside sliding window. |
----------------------------------------------------------------------
| < cached by vLLM > | not cached by |
| vLLM, but |
| ref_cnt | ref_cnt not | cached by |
| increased| increased yet| connector |
----------------------------------------------------------------------
``` ```
The following *_blocks are illustrated in this layout.
Abbreviations:
```
comp = request.num_computed_tokens
new_comp = num_new_computed_tokens
= len(new_computed_blocks) * block_size
ext_comp = num_external_computed_tokens, cached by the connector
new = num_new_tokens, including unverified draft tokens
lookahead = num_lookahead_tokens
```
NOTE: for new tokens which include both verified and unverified draft
tokens, we only cache the verified tokens (by capping the number at
`request.num_tokens`).
The allocation has three stages:
- Free unnecessary blocks in `comp` and check
if we have sufficient free blocks (return None if not).
- Handle prefix tokens (`comp + new_comp + ext_comp`):
- Free unnecessary blocks (e.g. outside sliding window)
- Allocate new blocks for `ext_comp` tokens inside
sliding window
- Allocate new blocks for tokens to be computed (`new + lookahead`)
Returns: Returns:
A list of new allocated blocks. A list of new allocated blocks.
""" """
if num_new_tokens == 0: # When loading KV data asynchronously, we may have zero new tokens to
raise ValueError("num_new_tokens must be greater than 0") # compute while still allocating slots for externally computed tokens.
if num_new_tokens == 0 and num_external_computed_tokens == 0:
raise ValueError(
"num_new_tokens must be greater than 0 when there are no "
"external computed tokens"
)
if new_computed_blocks is not None: if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks new_computed_block_list = new_computed_blocks.blocks
else: else:
new_computed_block_list = self.empty_kv_cache_blocks.blocks new_computed_block_list = self.empty_kv_cache_blocks.blocks
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_local_computed_tokens = (
request.num_computed_tokens + num_new_computed_tokens
)
total_computed_tokens = min(
num_local_computed_tokens + num_external_computed_tokens,
self.max_model_len,
)
num_tokens_need_slot = min(
total_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len,
)
# Free the blocks that are skipped during the attention computation # Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window). # (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to # We can do this even if we cannot schedule this request due to
...@@ -266,15 +319,7 @@ class KVCacheManager: ...@@ -266,15 +319,7 @@ class KVCacheManager:
# Should call this function before allocating new blocks to reduce # Should call this function before allocating new blocks to reduce
# the number of evicted blocks. # the number of evicted blocks.
self.coordinator.remove_skipped_blocks( self.coordinator.remove_skipped_blocks(
request.request_id, request.num_computed_tokens request.request_id, total_computed_tokens
)
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len,
) )
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
...@@ -282,25 +327,25 @@ class KVCacheManager: ...@@ -282,25 +327,25 @@ class KVCacheManager:
num_tokens=num_tokens_need_slot, num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list, new_computed_blocks=new_computed_block_list,
num_encoder_tokens=num_encoder_tokens, num_encoder_tokens=num_encoder_tokens,
total_computed_tokens=num_local_computed_tokens
+ num_external_computed_tokens,
) )
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks # Cannot allocate new blocks
return None return None
# Touch the computed blocks to make sure they won't be evicted. if (
if self.enable_caching: new_computed_block_list is not self.empty_kv_cache_blocks.blocks
self.block_pool.touch(new_computed_block_list) or num_external_computed_tokens > 0
else: ):
assert not any(new_computed_block_list), (
"Computed blocks should be empty when prefix caching is disabled"
)
if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
# Append the new computed blocks to the request blocks until now to # Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated. # avoid the case where the new blocks cannot be allocated.
self.coordinator.save_new_computed_blocks( self.coordinator.allocate_new_computed_blocks(
request.request_id, new_computed_block_list request_id=request.request_id,
new_computed_blocks=new_computed_block_list,
num_local_computed_tokens=num_local_computed_tokens,
num_external_computed_tokens=num_external_computed_tokens,
) )
new_blocks = self.coordinator.allocate_new_blocks( new_blocks = self.coordinator.allocate_new_blocks(
...@@ -312,12 +357,14 @@ class KVCacheManager: ...@@ -312,12 +357,14 @@ class KVCacheManager:
if not self.enable_caching or delay_cache_blocks: if not self.enable_caching or delay_cache_blocks:
return self.create_kv_cache_blocks(new_blocks) return self.create_kv_cache_blocks(new_blocks)
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + # NOTE(woosuk): We want to commit (cache) up to num_local_computed_tokens
# num_new_tokens, but must exclude "non-committable" tokens (e.g., # + num_external_computed_tokens + num_new_tokens, but must exclude
# draft tokens that could be rejected). Therefore, we cap the number # "non-committable" tokens (e.g., draft tokens that could be rejected).
# at `request.num_tokens`, ensuring only "finalized" tokens are cached. # Therefore, we cap the number at `request.num_tokens`, ensuring only
# "finalized" tokens are cached.
num_tokens_to_cache = min( num_tokens_to_cache = min(
num_computed_tokens + num_new_tokens, request.num_tokens total_computed_tokens + num_new_tokens,
request.num_tokens,
) )
self.coordinator.cache_blocks(request, num_tokens_to_cache) self.coordinator.cache_blocks(request, num_tokens_to_cache)
...@@ -333,6 +380,19 @@ class KVCacheManager: ...@@ -333,6 +380,19 @@ class KVCacheManager:
""" """
self.coordinator.free(request.request_id) self.coordinator.free(request.request_id)
def remove_skipped_blocks(
self, request_id: str, total_computed_tokens: int
) -> None:
"""Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block.
Args:
request_id: The request ID.
total_computed_tokens: The total number of computed tokens, including
local computed tokens and external computed tokens.
"""
self.coordinator.remove_skipped_blocks(request_id, total_computed_tokens)
def evict_blocks(self, block_ids: set[int]) -> None: def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs. """evict blocks from the prefix cache by their block IDs.
...@@ -408,7 +468,13 @@ class KVCacheManager: ...@@ -408,7 +468,13 @@ class KVCacheManager:
return self.get_blocks(request_id).get_block_ids() return self.get_blocks(request_id).get_block_ids()
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled.""" """Cache the blocks for the request, if enabled.
Args:
request: The request to cache the blocks.
num_computed_tokens: The number of computed tokens, including tokens
that are already cached and tokens to be cached.
"""
if self.enable_caching: if self.enable_caching:
self.coordinator.cache_blocks(request, num_computed_tokens) self.coordinator.cache_blocks(request, num_computed_tokens)
......
...@@ -587,10 +587,11 @@ class Scheduler(SchedulerInterface): ...@@ -587,10 +587,11 @@ class Scheduler(SchedulerInterface):
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
num_new_tokens + num_external_computed_tokens, num_new_tokens,
num_new_local_computed_tokens, num_new_computed_tokens=num_new_local_computed_tokens,
new_computed_blocks, new_computed_blocks=new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens, num_lookahead_tokens=effective_lookahead_tokens,
num_external_computed_tokens=num_external_computed_tokens,
delay_cache_blocks=load_kv_async, delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens, num_encoder_tokens=num_encoder_tokens,
) )
...@@ -606,7 +607,7 @@ class Scheduler(SchedulerInterface): ...@@ -606,7 +607,7 @@ class Scheduler(SchedulerInterface):
if self.connector is not None: if self.connector is not None:
self.connector.update_state_after_alloc( self.connector.update_state_after_alloc(
request, request,
new_computed_blocks + new_blocks, self.kv_cache_manager.get_blocks(request.request_id),
num_external_computed_tokens, num_external_computed_tokens,
) )
...@@ -1580,6 +1581,13 @@ class Scheduler(SchedulerInterface): ...@@ -1580,6 +1581,13 @@ class Scheduler(SchedulerInterface):
if self.connector is None: if self.connector is None:
return False, None return False, None
# Free any out-of-window prefix blocks before we hand the block table to
# the connector.
self.kv_cache_manager.remove_skipped_blocks(
request_id=request.request_id,
total_computed_tokens=request.num_tokens,
)
block_ids = self.kv_cache_manager.get_block_ids(request.request_id) block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
if not isinstance(self.connector, SupportsHMA): if not isinstance(self.connector, SupportsHMA):
......
...@@ -30,6 +30,7 @@ class SingleTypeKVCacheManager(ABC): ...@@ -30,6 +30,7 @@ class SingleTypeKVCacheManager(ABC):
self, self,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
block_pool: BlockPool, block_pool: BlockPool,
enable_caching: bool,
kv_cache_group_id: int, kv_cache_group_id: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
...@@ -48,6 +49,7 @@ class SingleTypeKVCacheManager(ABC): ...@@ -48,6 +49,7 @@ class SingleTypeKVCacheManager(ABC):
self.block_size *= dcp_world_size * pcp_world_size self.block_size *= dcp_world_size * pcp_world_size
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool self.block_pool = block_pool
self.enable_caching = enable_caching
# Mapping from request ID to blocks to track the blocks allocated # Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request # for each request, so that we can free the blocks when the request
...@@ -68,6 +70,7 @@ class SingleTypeKVCacheManager(ABC): ...@@ -68,6 +70,7 @@ class SingleTypeKVCacheManager(ABC):
request_id: str, request_id: str,
num_tokens: int, num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock], new_computed_blocks: Sequence[KVCacheBlock],
total_computed_tokens: int,
) -> int: ) -> int:
""" """
Get the number of blocks needed to be allocated for the request. Get the number of blocks needed to be allocated for the request.
...@@ -78,46 +81,121 @@ class SingleTypeKVCacheManager(ABC): ...@@ -78,46 +81,121 @@ class SingleTypeKVCacheManager(ABC):
tokens that are already allocated). tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the new_computed_blocks: The new computed blocks just hitting the
prefix caching. prefix caching.
total_computed_tokens: Include both local and external computed
tokens.
Returns: Returns:
The number of blocks. The number of blocks to allocate.
""" """
num_required_blocks = cdiv(num_tokens, self.block_size) num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = ( num_req_blocks = len(self.req_to_blocks.get(request_id, ()))
num_required_blocks
- len(new_computed_blocks) if request_id in self.num_cached_block:
- len(self.req_to_blocks[request_id]) # Fast-path: a running request won't have any new prefix-cache hits.
assert len(new_computed_blocks) == 0
# NOTE: With speculative decoding, request's blocks may be allocated
# for draft tokens which are later rejected. In this case,
# num_required_blocks may be smaller than num_req_blocks.
return max(num_required_blocks - num_req_blocks, 0)
num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
num_local_computed_blocks = len(new_computed_blocks) + num_req_blocks
# Number of whole blocks that are skipped by the attention window.
# If nothing is skipped, this is 0.
num_skipped_blocks = num_skipped_tokens // self.block_size
# We need blocks for the non-skipped suffix. If there are still
# local-computed blocks inside the window, they contribute to the
# required capacity; otherwise, skipped blocks dominate.
num_new_blocks = max(
num_required_blocks - max(num_skipped_blocks, num_local_computed_blocks),
0,
) )
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it will be changed from a free block # Among the `new_computed_blocks`, the first `num_skipped_blocks` worth
# to a computed block when the request is allocated, so we also count # of blocks are skipped; `num_req_blocks` of those may already be in
# it as needed to be allocated. # `req_to_blocks`, so only skip the remainder from `new_computed_blocks`.
num_evictable_computed_blocks = sum( num_skipped_new_computed_blocks = max(0, num_skipped_blocks - num_req_blocks)
blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
# If a computed block is an eviction candidate (in the free queue and
# ref_cnt == 0), it will be removed from the free queue when touched by
# the allocated request, so we must count it in the free-capacity check.
num_evictable_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks[num_skipped_new_computed_blocks:]
) )
return num_new_blocks + num_evictable_computed_blocks return num_new_blocks + num_evictable_blocks
def save_new_computed_blocks( def allocate_new_computed_blocks(
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] self,
request_id: str,
new_computed_blocks: Sequence[KVCacheBlock],
num_local_computed_tokens: int,
num_external_computed_tokens: int,
) -> None: ) -> None:
""" """
Add the new computed blocks to the request. Add the new computed blocks to the request. This involves three steps:
1. Touch the computed blocks to make sure they won't be evicted.
1.5. (Optional) For sliding window, skipped blocks are padded with null blocks.
2. Add the remaining computed blocks.
3. (Optional) For KV connectors, allocate new blocks for external computed
tokens (if any).
Args: Args:
request_id: The request ID. request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the new_computed_blocks: The new computed blocks just hitting the
prefix cache. prefix cache.
num_local_computed_tokens: The number of local computed tokens.
num_external_computed_tokens: The number of external computed tokens.
""" """
if request_id not in self.num_cached_block:
# A new request. if request_id in self.num_cached_block:
req_blocks = self.req_to_blocks[request_id] # Fast-path: a running request won't have any new prefix-cache hits.
assert len(req_blocks) == 0 # It should not have any new computed blocks.
req_blocks.extend(new_computed_blocks)
self.num_cached_block[request_id] = len(new_computed_blocks)
else:
# A running request. Should not have new computed blocks.
assert len(new_computed_blocks) == 0 assert len(new_computed_blocks) == 0
return
# A new request.
req_blocks = self.req_to_blocks[request_id]
assert len(req_blocks) == 0
num_total_computed_tokens = (
num_local_computed_tokens + num_external_computed_tokens
)
num_skipped_tokens = self.get_num_skipped_tokens(num_total_computed_tokens)
num_skipped_blocks = num_skipped_tokens // self.block_size
if num_skipped_blocks > 0:
# It is possible that all new computed blocks are skipped when
# num_skipped_blocks > len(new_computed_blocks).
new_computed_blocks = new_computed_blocks[num_skipped_blocks:]
# Some external computed tokens may be skipped too.
num_external_computed_tokens = min(
num_total_computed_tokens - num_skipped_tokens,
num_external_computed_tokens,
)
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self.block_pool.touch(new_computed_blocks)
else:
assert not any(new_computed_blocks), (
"Computed blocks should be empty when prefix caching is disabled"
)
# Skip blocks are padded with null blocks.
req_blocks.extend([self._null_block] * num_skipped_blocks)
# Add the remaining computed blocks.
req_blocks.extend(new_computed_blocks)
# All cached hits (including skipped nulls) are already cached; mark
# them so cache_blocks() will not try to re-cache blocks that already
# have a block_hash set.
self.num_cached_block[request_id] = len(req_blocks)
if num_external_computed_tokens > 0:
# Allocate new blocks for external computed tokens.
allocated_blocks = self.block_pool.get_new_blocks(
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
)
req_blocks.extend(allocated_blocks)
def allocate_new_blocks( def allocate_new_blocks(
self, request_id: str, num_tokens: int self, request_id: str, num_tokens: int
...@@ -252,7 +330,9 @@ class SingleTypeKVCacheManager(ABC): ...@@ -252,7 +330,9 @@ class SingleTypeKVCacheManager(ABC):
raise NotImplementedError raise NotImplementedError
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: def remove_skipped_blocks(
self, request_id: str, total_computed_tokens: int
) -> None:
""" """
Remove and free the blocks that are no longer needed for attention computation. Remove and free the blocks that are no longer needed for attention computation.
The removed blocks should be replaced by null_block. The removed blocks should be replaced by null_block.
...@@ -262,18 +342,24 @@ class SingleTypeKVCacheManager(ABC): ...@@ -262,18 +342,24 @@ class SingleTypeKVCacheManager(ABC):
Args: Args:
request_id: The request ID. request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed. total_computed_tokens: The total number of computed tokens, including
local computed tokens and external computed tokens.
""" """
# Remove the blocks that will be skipped during attention computation. # Remove the blocks that will be skipped during attention computation.
num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens) num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
if num_skipped_tokens <= 0: if num_skipped_tokens <= 0:
# This indicates that ALL tokens are inside attention window. # This indicates that ALL tokens are inside attention window.
# Thus we do not need to free any blocks outside attention window. # Thus we do not need to free any blocks outside attention window.
# A typical case is full attention that we never free any token # A typical case is full attention that we never free any token
# before the request is finished. # before the request is finished.
return return
num_skipped_blocks = num_skipped_tokens // self.block_size
blocks = self.req_to_blocks[request_id] blocks = self.req_to_blocks[request_id]
num_skipped_blocks = num_skipped_tokens // self.block_size
# `num_skipped_tokens` may include tokens that haven't been allocated yet
# (e.g., when the attention window moves into the external computed tokens
# range), so we must cap to the number of blocks that currently exist for
# this request.
num_skipped_blocks = min(num_skipped_blocks, len(blocks))
removed_blocks: list[KVCacheBlock] = [] removed_blocks: list[KVCacheBlock] = []
# Because the block starts from index 0, the num_skipped_block-th block # Because the block starts from index 0, the num_skipped_block-th block
# corresponds to index num_skipped_blocks - 1. # corresponds to index num_skipped_blocks - 1.
...@@ -486,7 +572,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): ...@@ -486,7 +572,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
Returns: Returns:
The number of tokens that will be skipped for attention computation. The number of tokens that will be skipped for attention computation.
""" """
return num_computed_tokens - self.sliding_window + 1 return max(0, num_computed_tokens - self.sliding_window + 1)
def get_num_common_prefix_blocks(self, running_request_id: str) -> int: def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
""" """
...@@ -711,6 +797,7 @@ class MambaManager(SingleTypeKVCacheManager): ...@@ -711,6 +797,7 @@ class MambaManager(SingleTypeKVCacheManager):
request_id: str, request_id: str,
num_tokens: int, num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock], new_computed_blocks: Sequence[KVCacheBlock],
total_computed_tokens: int,
) -> int: ) -> int:
# Allocate extra `num_speculative_blocks` blocks for # Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention. # speculative decoding (MTP/EAGLE) with linear attention.
...@@ -721,7 +808,7 @@ class MambaManager(SingleTypeKVCacheManager): ...@@ -721,7 +808,7 @@ class MambaManager(SingleTypeKVCacheManager):
* self.kv_cache_spec.num_speculative_blocks * self.kv_cache_spec.num_speculative_blocks
) )
return super().get_num_blocks_to_allocate( return super().get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks request_id, num_tokens, new_computed_blocks, total_computed_tokens
) )
def allocate_new_blocks( def allocate_new_blocks(
...@@ -749,8 +836,12 @@ class MambaManager(SingleTypeKVCacheManager): ...@@ -749,8 +836,12 @@ class MambaManager(SingleTypeKVCacheManager):
class CrossAttentionManager(SingleTypeKVCacheManager): class CrossAttentionManager(SingleTypeKVCacheManager):
"""Manager for cross-attention KV cache in encoder-decoder models.""" """Manager for cross-attention KV cache in encoder-decoder models."""
def save_new_computed_blocks( def allocate_new_computed_blocks(
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] self,
request_id: str,
new_computed_blocks: Sequence[KVCacheBlock],
num_local_computed_tokens: int,
num_external_computed_tokens: int,
) -> None: ) -> None:
# We do not cache blocks for cross-attention to be shared between # We do not cache blocks for cross-attention to be shared between
# requests, so `new_computed_blocks` should always be empty. # requests, so `new_computed_blocks` should always be empty.
......
...@@ -624,7 +624,7 @@ class Worker(WorkerBase): ...@@ -624,7 +624,7 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model( output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors scheduler_output, intermediate_tensors
) )
if isinstance(output, (ModelRunnerOutput, NoneType)): if isinstance(output, ModelRunnerOutput | NoneType):
return output return output
assert isinstance(output, IntermediateTensors) assert isinstance(output, IntermediateTensors)
......
...@@ -304,6 +304,13 @@ class TPUWorker: ...@@ -304,6 +304,13 @@ class TPUWorker:
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config.""" """Allocate GPU KV cache with the specified kv_cache_config."""
# Init kv cache connector here, because it requires
# `kv_cache_config`.
# NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
self.model_runner.initialize_kv_cache(kv_cache_config) self.model_runner.initialize_kv_cache(kv_cache_config)
def check_health(self) -> None: def check_health(self) -> None:
...@@ -336,8 +343,6 @@ class TPUWorker: ...@@ -336,8 +343,6 @@ class TPUWorker:
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
) )
ensure_kv_transfer_initialized(vllm_config)
def shutdown(self) -> None: def shutdown(self) -> None:
self.model_runner.ensure_kv_transfer_shutdown() self.model_runner.ensure_kv_transfer_shutdown()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment