Unverified Commit efe73e9b authored by Kuntai Du, committed by GitHub
Browse files

[Core][Hybrid allocator + connector 2/n] Unify `remove_skipped_blocks` by...


[Core][Hybrid allocator + connector 2/n] Unify `remove_skipped_blocks` by `get_last_useful_token` (#25431)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
parent 0b8e871e
......@@ -243,18 +243,53 @@ class SingleTypeKVCacheManager(ABC):
raise NotImplementedError
@abstractmethod
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
    """
    Remove and free the blocks that are no longer needed for attention
    computation.

    Every removed entry in the request's block list is replaced by the null
    block. How many leading tokens can be dropped is decided by
    `get_num_skipped_tokens`, which needs to be implemented differently for
    each attention type.

    Args:
        request_id: The request ID.
        num_computed_tokens: The number of tokens that have been computed.
    """
    # Remove the blocks that will be skipped during attention computation.
    num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
    if num_skipped_tokens <= 0:
        # ALL tokens are still inside the attention window, so there is
        # nothing outside of it to free. A typical case is full attention,
        # where no token is freed before the request is finished.
        return

    # Only blocks lying entirely inside the skipped prefix can be released,
    # hence the floor division.
    num_skipped_blocks = num_skipped_tokens // self.block_size
    blocks = self.req_to_blocks[request_id]
    removed_blocks: list[KVCacheBlock] = []
    # Block indices start at 0, so the num_skipped_blocks-th block sits at
    # index num_skipped_blocks - 1. Walk backwards from there.
    for i in range(num_skipped_blocks - 1, -1, -1):
        if blocks[i] == self._null_block:
            # If the block is already a null block, the blocks before it
            # should also have been set to null blocks by the previous calls
            # to this function.
            break
        removed_blocks.append(blocks[i])
        blocks[i] = self._null_block
    self.block_pool.free_blocks(removed_blocks)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
    """
    Report how many leading tokens attention computation will skip.

    This default implementation never skips any token, whatever the number
    of computed tokens is.

    Args:
        num_computed_tokens: The number of tokens that have been computed.

    Returns:
        The number of tokens that will be skipped for attention computation
        (always 0 here).
    """
    # Default behavior: nothing is skipped.
    num_skipped_tokens = 0
    return num_skipped_tokens
class FullAttentionManager(SingleTypeKVCacheManager):
......@@ -298,10 +333,6 @@ class FullAttentionManager(SingleTypeKVCacheManager):
computed.pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
    """Do nothing: there is no need to remove blocks for full attention."""
    return None
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
blocks = self.req_to_blocks[running_request_id]
num_common_blocks = 0
......@@ -389,28 +420,33 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
computed.pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
    """
    Free the blocks that have dropped out of the sliding window.

    Blocks lying entirely before the window are skipped during attention
    computation. Each of them is replaced by the null block in the request's
    block list and handed to the block pool to be freed.

    Args:
        request_id: The request ID.
        num_computed_tokens: The number of tokens that have been computed.
    """
    first_token_in_window = num_computed_tokens - self.sliding_window + 1
    first_block_in_window = first_token_in_window // self.block_size
    if first_block_in_window <= 0:
        # Not enough tokens yet to fill the sliding window.
        return

    blocks = self.req_to_blocks[request_id]
    null_block = self._null_block
    if blocks[first_block_in_window - 1] == null_block:
        # The newest removable block is already gone: nothing to remove.
        return

    removed_blocks: list[KVCacheBlock] = []
    # Walk backwards from the newest removable block.
    for idx in reversed(range(first_block_in_window)):
        if blocks[idx] == null_block:
            # A null block means every earlier block was already nulled out
            # by previous calls to this function.
            break
        removed_blocks.append(blocks[idx])
        blocks[idx] = null_block
    self.block_pool.free_blocks(removed_blocks)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
    """
    Get the number of tokens that will be skipped for attention computation.

    For sliding window, these are the tokens that come before the current
    sliding window.

    Example:
    sliding_window=4, num_computed_tokens=7

    Tokens:   [ 0  1  2  3  4  5  6  7 ]
              | ---- computed -----|
                                     ^ next token to be computed
                           |-----------| sliding window for next token
              |--skipped---|

    The current window contains tokens 4~7. Tokens 0~3 will be skipped for
    attention computation since they are outside the sliding window.
    Thus, get_num_skipped_tokens(7) == 4.

    Args:
        num_computed_tokens: The number of tokens that have been computed.

    Returns:
        The number of tokens that will be skipped for attention computation.
        The value is zero or negative while the computed tokens do not yet
        fill the window.
    """
    # The window covers the next token plus the (sliding_window - 1) tokens
    # right before it; everything earlier is skipped.
    window_start = (num_computed_tokens + 1) - self.sliding_window
    return window_start
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
......@@ -511,40 +547,51 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
break
return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
    """
    Free the blocks that lie before the current chunked attention window.

    Layout: [chunk 0][chunk 1]local_attention_start_idx ... current

    The window start is found by counting the chunks that precede the
    current one. E.g. (with a chunk size of 1024) for 1024 computed tokens
    the 1024th token (0 indexed) is in the second chunk, there is 1 previous
    chunk, so the start index is 1024; for 1023 it is 0.

    Args:
        request_id: The request ID.
        num_computed_tokens: The number of tokens that have been computed.
    """
    cached_block_count = self.num_cached_block.get(request_id, 0)
    chunk_size = self.attention_chunk_size
    window_start = num_computed_tokens // chunk_size * chunk_size
    # With block size = 128: 0 -> block 0, 1024 (= 128 * 8) -> block 8,
    # 372 (= 128 * 2 + 116) -> block 2.
    first_kept_block = window_start // self.block_size
    if cached_block_count > 0:
        # Make sure we don't delete the last cached block; it has to be
        # kept to get the previous hash key.
        first_kept_block = min(first_kept_block, cached_block_count - 1)

    blocks = self.req_to_blocks[request_id]
    null_block = self._null_block
    removed_blocks: list[KVCacheBlock] = []
    # Walk backwards from the block right before the first kept one.
    for idx in reversed(range(first_kept_block)):
        if blocks[idx] == null_block:
            # A null block means every earlier block was already nulled out
            # by previous calls to this function.
            break
        removed_blocks.append(blocks[idx])
        blocks[idx] = null_block
    self.block_pool.free_blocks(removed_blocks)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
    """
    Get the number of tokens that will be skipped for attention computation.

    For chunked local attention, these are the tokens on the left side of
    the current chunk.

    Example 1:
    chunk size = 8, num_computed_tokens = 13
    Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
             | ----- computed ---------------|
                                              ^^ next token to be computed
                               |----------------| <-- attention window for
                                                      next token
             |--- skipped -----|
    Output: get_num_skipped_tokens(13) == 8

    Example 2:
    chunk size = 8, num_computed_tokens = 8
    Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
             | --- computed ---|
                                 ^ next token to be computed
                               |--| <-- attention window for next token
             | --- skipped ----|
    Output: get_num_skipped_tokens(8) == 8

    Example 3:
    chunk size = 8, num_computed_tokens = 7
    Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
             |---computed---|
                             ^ next token to be computed
             |-----------------| <-- attention window for next token
    no token should be skipped.
    Output: get_num_skipped_tokens(7) == 0

    Args:
        num_computed_tokens: The number of tokens that have been computed.

    Returns:
        The number of tokens that will be skipped for attention computation.
    """
    # Round down to the nearest chunk boundary: drop the tokens already
    # sitting in the current (partially filled) chunk.
    tokens_in_current_chunk = num_computed_tokens % self.attention_chunk_size
    return num_computed_tokens - tokens_in_current_chunk
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
......@@ -590,12 +637,6 @@ class MambaManager(SingleTypeKVCacheManager):
return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
    """Currently a no-op: no block is removed or freed here."""
    # Here unused blocks may be freed up for running requests.
    # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
    #  (for which find_longest_cache_hit returns block_pool.null_block)
    return None
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
cascade attention is not supported by mamba
......@@ -676,11 +717,6 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
# Return empty blocks to indicate no cache hits
raise NotImplementedError("CrossAttentionManager does not support caching")
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
    """
    Do nothing: no cross-attention block is ever skipped.

    Cross-attention blocks represent encoder states, which are needed for
    the entire decoding process.
    """
    return None
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment