Unverified commit 5095e966, authored by Cody Yu, committed by GitHub
Browse files

[V1] Revert `uncache_blocks` and support recaching full blocks (#12415)


Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
parent cf58b9c4
...@@ -629,33 +629,3 @@ def test_reset_prefix_cache(): ...@@ -629,33 +629,3 @@ def test_reset_prefix_cache():
assert manager.reset_prefix_cache() assert manager.reset_prefix_cache()
assert not manager.cached_block_hash_to_block assert not manager.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool]) assert all([blk.block_hash is None for blk in manager.block_pool])
def test_uncache_blocks():
    """Check that blocks cached only because of speculative tokens are
    uncached again when those tokens are not accepted.

    The manager uses a block size of 16: a 30-token prompt spans two blocks
    and leaves one cached entry; five speculative tokens raise the cached
    count to two; rolling ``num_computed_tokens`` back to 31 must uncache
    exactly one block.
    """
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    req0 = make_request("0", list(range(30)))
    blocks = manager.allocate_slots(req0, 30)
    assert [b.block_id for b in blocks] == [0, 1]
    assert len(manager.cached_block_hash_to_block) == 1

    req0.num_computed_tokens = 30

    # Simulate speculative tokens.
    for _ in range(5):
        req0.append_output_token_ids(8)
    # Allocate once for all five speculative tokens (30 + 5 = 35 tokens),
    # after which the second block is expected to be cached as well.
    manager.allocate_slots(req0, 5)
    assert len(manager.cached_block_hash_to_block) == 2

    # After sampling, assuming only 1 token is accepted.
    req0.num_computed_tokens = 31
    num_uncached_blocks = manager.uncache_blocks(req0)
    assert num_uncached_blocks == 1
    assert len(manager.cached_block_hash_to_block) == 1
...@@ -252,29 +252,6 @@ class KVCacheManager: ...@@ -252,29 +252,6 @@ class KVCacheManager:
if block.ref_cnt == 0: if block.ref_cnt == 0:
self.free_block_queue.append(block) self.free_block_queue.append(block)
def uncache_blocks(self, request: Request) -> int:
    """Uncache the blocks that are no longer full based on the
    num_computed_tokens in the given request. This happens when
    the blocks were full and cached due to speculative tokens, but the
    speculative tokens are not accepted.

    Args:
        request: The request.

    Returns:
        The number of uncached blocks.

    Raises:
        KeyError: If the request has no entry in ``self.req_to_blocks``.
    """
    blocks = self.req_to_blocks[request.request_id]
    num_computed_tokens = request.num_computed_tokens
    # Blocks before this index are completely covered by computed tokens
    # and stay cached; only the blocks from this index on are candidates.
    num_full_blocks = num_computed_tokens // self.block_size

    num_uncached_blocks = 0
    for block in blocks[num_full_blocks:]:
        # If the block is not cached, the following blocks are not cached.
        # (The helper's falsy result is treated as "was not cached".)
        if not self._maybe_evict_cached_block(block):
            break
        num_uncached_blocks += 1
    return num_uncached_blocks
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF """Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated, flows to invalid prefix caching after the weights are updated,
...@@ -470,8 +447,22 @@ class KVCacheManager: ...@@ -470,8 +447,22 @@ class KVCacheManager:
assert prev_block.block_hash is not None assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value prev_block_hash_value = prev_block.block_hash.hash_value
for i, blk in enumerate(full_blocks): # Find the first uncached block. This case should only happen when
blk_idx = blk_start_idx + i # speculative decoding is used.
offset = 0
for blk in full_blocks:
if blk.block_hash is None:
break
else:
prev_block_hash_value = blk.block_hash.hash_value
offset += 1
else:
# All blocks are cached.
return
for i, blk in enumerate(full_blocks[offset:]):
blk_idx = blk_start_idx + offset + i
assert blk.block_hash is None
if blk_idx < num_cached_block_hashes: if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in # The block hash may already be computed in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment