Unverified commit 6e0b6468, authored by Zhiqiang Xie, committed by GitHub
Browse files

HiCache Storage tp fix (#8878)

parent 4a9f3eef
...@@ -570,10 +570,6 @@ class HiCacheController: ...@@ -570,10 +570,6 @@ class HiCacheController:
) )
completed_tokens += self.page_size completed_tokens += self.page_size
else: else:
# operation terminated by controller, release pre-allocated memory
self.mem_pool_host.free(
operation.host_indices[operation.completed_tokens :]
)
break break
def mooncake_page_transfer(self, operation): def mooncake_page_transfer(self, operation):
...@@ -599,6 +595,14 @@ class HiCacheController: ...@@ -599,6 +595,14 @@ class HiCacheController:
self.generic_page_transfer(operation, batch_size=128) self.generic_page_transfer(operation, batch_size=128)
else: else:
self.generic_page_transfer(operation) self.generic_page_transfer(operation)
if self.tp_world_size > 1:
# to ensure all TP workers release the host memory at the same time
torch.distributed.barrier(group=self.prefetch_tp_group)
# operation terminated by controller, release pre-allocated memory
self.mem_pool_host.free(
operation.host_indices[operation.completed_tokens :]
)
except Empty: except Empty:
continue continue
...@@ -626,7 +630,9 @@ class HiCacheController: ...@@ -626,7 +630,9 @@ class HiCacheController:
continue continue
storage_hit_count = 0 storage_hit_count = 0
if self.prefetch_rate_limit_check(): if (
operation.host_indices is not None
) and self.prefetch_rate_limit_check():
last_hash = operation.last_hash last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids tokens_to_fetch = operation.token_ids
...@@ -670,6 +676,7 @@ class HiCacheController: ...@@ -670,6 +676,7 @@ class HiCacheController:
if storage_hit_count < self.prefetch_threshold: if storage_hit_count < self.prefetch_threshold:
# not to prefetch if not enough benefits # not to prefetch if not enough benefits
self.prefetch_revoke_queue.put(operation.request_id) self.prefetch_revoke_queue.put(operation.request_id)
if operation.host_indices is not None:
self.mem_pool_host.free(operation.host_indices) self.mem_pool_host.free(operation.host_indices)
logger.debug( logger.debug(
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})." f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
......
...@@ -471,6 +471,10 @@ class HiRadixCache(RadixCache): ...@@ -471,6 +471,10 @@ class HiRadixCache(RadixCache):
req_id req_id
] ]
if operation.host_indices is None:
# prefetch has not been issued due to insufficient host memory
return True
if not self.can_terminate_prefetch(operation): if not self.can_terminate_prefetch(operation):
return False return False
...@@ -565,10 +569,6 @@ class HiRadixCache(RadixCache): ...@@ -565,10 +569,6 @@ class HiRadixCache(RadixCache):
if host_indices is None: if host_indices is None:
self.evict_host(prefetch_length) self.evict_host(prefetch_length)
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length) host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
if host_indices is None:
last_host_node.release_host()
# no sufficient host memory to prefetch
return
operation = self.cache_controller.prefetch( operation = self.cache_controller.prefetch(
req_id, host_indices, new_input_tokens, last_hash req_id, host_indices, new_input_tokens, last_hash
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment