Unverified Commit fbdb5050 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Hot fix for hicache with new page aligned radixtree (#4397)

parent f0afaf52
......@@ -248,6 +248,8 @@ class HiCacheController:
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
# to ensure the device indices are ready before accessed by another CUDA stream
torch.cuda.current_stream().synchronize()
self.load_queue.put(
CacheOperation(host_indices, device_indices, node_id, priority)
)
......
......@@ -434,6 +434,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
page_size=self.page_size,
)
else:
self.tree_cache = RadixCache(
......
......@@ -25,11 +25,17 @@ class HiRadixCache(RadixCache):
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tp_cache_group: torch.distributed.ProcessGroup,
page_size: int,
):
if page_size != 1:
raise ValueError(
"Page size larger than 1 is not yet supported in HiRadixCache."
)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.tp_group = tp_cache_group
self.page_size = page_size
self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController(
......@@ -45,7 +51,9 @@ class HiRadixCache(RadixCache):
# todo: dynamically adjust the threshold
self.write_through_threshold = 1
self.load_back_threshold = 10
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
)
def reset(self):
TreeNode.counter = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment