Unverified Commit fbdb5050 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Hot fix for hicache with new page aligned radixtree (#4397)

parent f0afaf52
...@@ -248,6 +248,8 @@ class HiCacheController: ...@@ -248,6 +248,8 @@ class HiCacheController:
if device_indices is None: if device_indices is None:
return None return None
self.mem_pool_host.protect_load(host_indices) self.mem_pool_host.protect_load(host_indices)
# to ensure the device indices are ready before accessed by another CUDA stream
torch.cuda.current_stream().synchronize()
self.load_queue.put( self.load_queue.put(
CacheOperation(host_indices, device_indices, node_id, priority) CacheOperation(host_indices, device_indices, node_id, priority)
) )
......
...@@ -434,6 +434,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -434,6 +434,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_worker.get_tp_cpu_group(), tp_cache_group=self.tp_worker.get_tp_cpu_group(),
page_size=self.page_size,
) )
else: else:
self.tree_cache = RadixCache( self.tree_cache = RadixCache(
......
...@@ -25,11 +25,17 @@ class HiRadixCache(RadixCache): ...@@ -25,11 +25,17 @@ class HiRadixCache(RadixCache):
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tp_cache_group: torch.distributed.ProcessGroup, tp_cache_group: torch.distributed.ProcessGroup,
page_size: int,
): ):
if page_size != 1:
raise ValueError(
"Page size larger than 1 is not yet supported in HiRadixCache."
)
self.token_to_kv_pool_host = MHATokenToKVPoolHost( self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache() token_to_kv_pool_allocator.get_kvcache()
) )
self.tp_group = tp_cache_group self.tp_group = tp_cache_group
self.page_size = page_size
self.load_cache_event = threading.Event() self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController( self.cache_controller = HiCacheController(
...@@ -45,7 +51,9 @@ class HiRadixCache(RadixCache): ...@@ -45,7 +51,9 @@ class HiRadixCache(RadixCache):
# todo: dynamically adjust the threshold # todo: dynamically adjust the threshold
self.write_through_threshold = 1 self.write_through_threshold = 1
self.load_back_threshold = 10 self.load_back_threshold = 10
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False) super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
)
def reset(self): def reset(self):
TreeNode.counter = 0 TreeNode.counter = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment