Unverified Commit 37565b7f authored by JinYan Su, committed by GitHub

fix(cache): move ongoing_prefetch pop after validation to prevent leak (#9927)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 6243c367
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
-            req_id
-        )
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+            req_id
+        ]
         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
@@ -512,6 +512,7 @@ class HiRadixCache(RadixCache):
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
         return True
@@ -775,9 +776,7 @@ class HiRadixCache(RadixCache):
         if rid not in self.ongoing_prefetch:
             return
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
-            rid
-        )
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
         if operation.host_indices is None:
             return
@@ -785,5 +784,6 @@ class HiRadixCache(RadixCache):
         if self.tp_world_size > 1:
             torch.distributed.barrier(group=self.tp_group)
         last_host_node.release_host()
+        del self.ongoing_prefetch[rid]
         self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
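For context, the pattern the change applies is "look up first, delete only after cleanup": the ongoing_prefetch entry is read without removal, and the `del` happens only once the prefetch has been validated and its host memory and token accounting are released, so the early-return paths no longer drop the entry before cleanup can run. Below is a minimal, self-contained sketch of that pattern; the names (PrefetchEntry, registry, release_cb) are illustrative stand-ins, not the actual HiRadixCache internals.

```python
# Minimal sketch of the "look up first, delete only after cleanup" pattern.
# PrefetchEntry, registry, and release_cb are hypothetical stand-ins for the
# real HiRadixCache bookkeeping, used only to illustrate the leak and the fix.
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class PrefetchEntry:
    host_indices: Optional[List[int]]  # None => prefetch was never issued
    token_ids: List[int]


def check_prefetch_leaky(registry: Dict[str, PrefetchEntry], req_id: str) -> bool:
    # Buggy variant: pop() removes the entry even when we bail out early,
    # so nothing ever releases its resources or adjusts the counters.
    entry = registry.pop(req_id)
    if entry.host_indices is None:
        return False  # entry already gone from the registry -> leak
    # ... release host memory, update occupancy counters ...
    return True


def check_prefetch_fixed(
    registry: Dict[str, PrefetchEntry],
    req_id: str,
    release_cb: Callable[[PrefetchEntry], None],
) -> bool:
    # Fixed variant: read without removing; an early return keeps the entry
    # in the registry so a later check or termination path can clean it up.
    entry = registry[req_id]
    if entry.host_indices is None:
        return False
    release_cb(entry)      # release host memory, adjust counters
    del registry[req_id]   # only now is the bookkeeping entry dropped
    return True
```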