Unverified Commit 54e872d3 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

[HiCache] resolve conflict between chunked-prefill and hicache hit count (#9776)

parent e5b29bf1
...@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin:
# Move the chunked request out of the batch so that we can merge # Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch. # only finished requests to running_batch.
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req) self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req) self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
if self.enable_overlap: if self.enable_overlap:
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
self.chunked_req.tmp_end_idx = min( self.chunked_req.tmp_end_idx = min(
......
...@@ -1503,7 +1503,7 @@ class Scheduler( ...@@ -1503,7 +1503,7 @@ class Scheduler(
# Move the chunked request out of the batch so that we can merge # Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch. # only finished requests to running_batch.
chunked_req_to_exclude.add(self.chunked_req) chunked_req_to_exclude.add(self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req) self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
# chunked request keeps its rid but will get a new req_pool_idx # chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx) self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
if self.last_batch and self.last_batch.forward_mode.is_extend(): if self.last_batch and self.last_batch.forward_mode.is_extend():
......
...@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache): ...@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool_allocator.free(kv_indices) self.token_to_kv_pool_allocator.free(kv_indices)
def cache_unfinished_req(self, req: Req): def cache_unfinished_req(self, req: Req, chunked=False):
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(req.fill_ids) req.req_pool_idx, : len(req.fill_ids)
] ]
......
...@@ -102,7 +102,7 @@ class HiRadixCache(RadixCache): ...@@ -102,7 +102,7 @@ class HiRadixCache(RadixCache):
self.ongoing_backup = {} self.ongoing_backup = {}
# todo: dynamically adjust the threshold # todo: dynamically adjust the threshold
self.write_through_threshold = ( self.write_through_threshold = (
1 if hicache_write_policy == "write_through" else 3 1 if hicache_write_policy == "write_through" else 2
) )
self.write_through_threshold_storage = ( self.write_through_threshold_storage = (
1 if hicache_write_policy == "write_through" else 3 1 if hicache_write_policy == "write_through" else 3
...@@ -155,8 +155,9 @@ class HiRadixCache(RadixCache): ...@@ -155,8 +155,9 @@ class HiRadixCache(RadixCache):
self.ongoing_backup[operation_id] = node self.ongoing_backup[operation_id] = node
node.protect_host() node.protect_host()
def inc_hit_count(self, node: TreeNode): def _inc_hit_count(self, node: TreeNode, chunked=False):
if self.cache_controller.write_policy == "write_back": # skip the hit count update for chunked requests
if self.cache_controller.write_policy == "write_back" or chunked:
return return
node.hit_count += 1 node.hit_count += 1
...@@ -672,11 +673,11 @@ class HiRadixCache(RadixCache): ...@@ -672,11 +673,11 @@ class HiRadixCache(RadixCache):
new_node.parent.children[self.get_child_key_fn(key)] = new_node new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node return new_node
def _insert_helper(self, node: TreeNode, key: List, value): def insert(self, key: List, value, chunked=False):
node.last_access_time = time.monotonic()
if len(key) == 0: if len(key) == 0:
return 0 return 0
node = self.root_node
child_key = self.get_child_key_fn(key) child_key = self.get_child_key_fn(key)
total_prefix_length = 0 total_prefix_length = 0
...@@ -693,7 +694,7 @@ class HiRadixCache(RadixCache): ...@@ -693,7 +694,7 @@ class HiRadixCache(RadixCache):
self.token_to_kv_pool_host.update_synced(node.host_value) self.token_to_kv_pool_host.update_synced(node.host_value)
self.evictable_size_ += len(node.value) self.evictable_size_ += len(node.value)
else: else:
self.inc_hit_count(node) self._inc_hit_count(node, chunked)
total_prefix_length += prefix_len total_prefix_length += prefix_len
else: else:
# partial match, split the node # partial match, split the node
...@@ -703,7 +704,7 @@ class HiRadixCache(RadixCache): ...@@ -703,7 +704,7 @@ class HiRadixCache(RadixCache):
self.token_to_kv_pool_host.update_synced(new_node.host_value) self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value) self.evictable_size_ += len(new_node.value)
else: else:
self.inc_hit_count(new_node) self._inc_hit_count(new_node, chunked)
total_prefix_length += prefix_len total_prefix_length += prefix_len
node = new_node node = new_node
...@@ -737,7 +738,7 @@ class HiRadixCache(RadixCache): ...@@ -737,7 +738,7 @@ class HiRadixCache(RadixCache):
last_hash = new_node.hash_value[-1] last_hash = new_node.hash_value[-1]
if self.cache_controller.write_policy != "write_back": if self.cache_controller.write_policy != "write_back":
self.inc_hit_count(new_node) self._inc_hit_count(new_node, chunked)
return total_prefix_length return total_prefix_length
def _collect_leaves_device(self): def _collect_leaves_device(self):
......
...@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache): ...@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node) self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: Req): def cache_unfinished_req(self, req: Req, chunked=False):
"""Cache request when it is unfinished.""" """Cache request when it is unfinished."""
if self.disable: if self.disable:
return return
......
...@@ -195,7 +195,7 @@ class RadixCache(BasePrefixCache): ...@@ -195,7 +195,7 @@ class RadixCache(BasePrefixCache):
last_host_node=last_node, last_host_node=last_node,
) )
def insert(self, key: List, value=None): def insert(self, key: List, value=None, chunked=False):
if self.disable: if self.disable:
return 0 return 0
...@@ -240,7 +240,7 @@ class RadixCache(BasePrefixCache): ...@@ -240,7 +240,7 @@ class RadixCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node) self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: Req): def cache_unfinished_req(self, req: Req, chunked=False):
"""Cache request when it is unfinished.""" """Cache request when it is unfinished."""
if self.disable: if self.disable:
return return
...@@ -261,7 +261,9 @@ class RadixCache(BasePrefixCache): ...@@ -261,7 +261,9 @@ class RadixCache(BasePrefixCache):
page_aligned_token_ids = token_ids[:page_aligned_len] page_aligned_token_ids = token_ids[:page_aligned_len]
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices) new_prefix_len = self.insert(
page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
)
self.token_to_kv_pool_allocator.free( self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len] kv_indices[len(req.prefix_indices) : new_prefix_len]
) )
......
...@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache): ...@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
self.dec_lock_ref(req.last_node) self.dec_lock_ref(req.last_node)
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
def cache_unfinished_req(self, req: Req): def cache_unfinished_req(self, req: Req, chunked=False):
"""Cache request when it is unfinished.""" """Cache request when it is unfinished."""
assert req.req_pool_idx is not None assert req.req_pool_idx is not None
token_ids = req.fill_ids token_ids = req.fill_ids
......
...@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache): ...@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
def cache_unfinished_req(self, req: Req) -> None: def cache_unfinished_req(self, req: Req, chunked=False) -> None:
"""Cache request when it is unfinished.""" """Cache request when it is unfinished."""
if self.disable: if self.disable:
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment