Unverified Commit 528bd1ed authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

HiCache, check before terminate prefetching (#8372)

parent 62a6b7c7
...@@ -201,8 +201,9 @@ class PrefetchOperation(StorageOperation): ...@@ -201,8 +201,9 @@ class PrefetchOperation(StorageOperation):
def increment(self, num_tokens: int): def increment(self, num_tokens: int):
with self._lock: with self._lock:
if self._done_flag: if self._done_flag:
return return False
self.completed_tokens += num_tokens self.completed_tokens += num_tokens
return True
def mark_done(self): def mark_done(self):
with self._lock: with self._lock:
...@@ -528,12 +529,12 @@ class HiCacheController: ...@@ -528,12 +529,12 @@ class HiCacheController:
f"Prefetch operation {operation.request_id} failed to retrieve page {h}." f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
) )
break break
if operation.increment(self.page_size):
self.mem_pool_host.set_from_flat_data_page( self.mem_pool_host.set_from_flat_data_page(
operation.host_indices[operation.completed_tokens], operation.host_indices[operation.completed_tokens],
page_data, page_data,
) )
operation.increment(self.page_size) else:
if operation.is_done():
# operation terminated by controller, release pre-allocated memory # operation terminated by controller, release pre-allocated memory
self.mem_pool_host.free( self.mem_pool_host.free(
operation.host_indices[operation.completed_tokens :] operation.host_indices[operation.completed_tokens :]
...@@ -589,6 +590,7 @@ class HiCacheController: ...@@ -589,6 +590,7 @@ class HiCacheController:
if storage_hit_count < self.prefetch_threshold: if storage_hit_count < self.prefetch_threshold:
# not to prefetch if not enough benefits # not to prefetch if not enough benefits
self.prefetch_revoke_queue.put(operation.request_id) self.prefetch_revoke_queue.put(operation.request_id)
self.mem_pool_host.free(operation.host_indices)
logger.debug( logger.debug(
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})." f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
) )
......
...@@ -365,10 +365,12 @@ class HiRadixCache(RadixCache): ...@@ -365,10 +365,12 @@ class HiRadixCache(RadixCache):
for _ in range(queue_size.item()): for _ in range(queue_size.item()):
req_id = self.cache_controller.prefetch_revoke_queue.get() req_id = self.cache_controller.prefetch_revoke_queue.get()
if req_id in self.ongoing_prefetch: if req_id in self.ongoing_prefetch:
last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id] last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
last_host_node.release_host() last_host_node.release_host()
self.cache_controller.mem_pool_host.free(host_indices)
del self.ongoing_prefetch[req_id] del self.ongoing_prefetch[req_id]
else:
# the revoked operation already got terminated
pass
def check_backup_progress(self): def check_backup_progress(self):
queue_size = torch.tensor( queue_size = torch.tensor(
...@@ -403,6 +405,7 @@ class HiRadixCache(RadixCache): ...@@ -403,6 +405,7 @@ class HiRadixCache(RadixCache):
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[ last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
req_id req_id
] ]
completed_tokens, hash_value = self.cache_controller.terminate_prefetch( completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
operation operation
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment