"src/vscode:/vscode.git/clone" did not exist on "a41e4c506bea0179ac6e556620c7ed45cc4c5f29"
Unverified commit cb9e0e41 authored by huangtingwei, committed by GitHub

[HiCacheStorage] fix abort request host memory leaks (#9874)
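
Before this commit, aborting a request that was still in the scheduler's waiting queue left its entry in HiRadixCache.ongoing_prefetch in place, so the host (CPU) memory pinned by the in-flight prefetch was never returned. The sketch below is a minimal, self-contained model of that failure mode and of the fix, not the SGLang implementation; HostPool, Cache, and start_prefetch are invented names.

# Minimal model (hypothetical names) of the leak: aborting a queued request
# without releasing its in-flight prefetch strands the host buffer forever.

class HostPool:
    def __init__(self, size: int):
        self.free = size

    def allocate(self, n: int):
        assert self.free >= n, "host memory exhausted"
        self.free -= n
        return list(range(n))  # stand-in for buffer indices

    def release(self, indices):
        self.free += len(indices)


class Cache:
    def __init__(self, pool: HostPool):
        self.pool = pool
        self.ongoing_prefetch = {}  # rid -> host indices pinned by the prefetch

    def start_prefetch(self, rid: str, n: int):
        self.ongoing_prefetch[rid] = self.pool.allocate(n)

    def release_aborted_request(self, rid: str):
        indices = self.ongoing_prefetch.pop(rid, None)
        if indices is not None:
            self.pool.release(indices)


pool = HostPool(size=8)
cache = Cache(pool)
cache.start_prefetch("req-1", 4)

# Without a release hook, aborting drops the request but leaves
# ongoing_prefetch untouched, so pool.free stays at 4 forever. With it:
cache.release_aborted_request("req-1")
assert pool.free == 8  # host memory fully recovered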


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 9db80253
@@ -2403,6 +2403,9 @@ class Scheduler(
                 # This only works for requests that have not started anything.
                 # We still need to send something back to TokenizerManager to clean up the state.
                 req = self.waiting_queue.pop(i)
+                if self.enable_hicache_storage:
+                    # to release prefetch events associated with the request
+                    self.tree_cache.release_aborted_request(req.rid)
                 self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
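
For context, the sketch below condenses the abort path touched above into a runnable stub (MiniScheduler, StubCache, StubChannel, and Req are all invented for illustration; the real Scheduler is far larger). The hicache release runs before the AbortReq is forwarded to the tokenizer side, while the request's prefetch state is still tracked.

from dataclasses import dataclass, field
from typing import List

@dataclass
class AbortReq:
    rid: str

class StubCache:
    def release_aborted_request(self, rid: str) -> None:
        print(f"released prefetch state for {rid}")

class StubChannel:
    def send_pyobj(self, obj) -> None:
        print(f"sent {obj}")

@dataclass
class Req:
    rid: str

@dataclass
class MiniScheduler:
    enable_hicache_storage: bool
    tree_cache: StubCache = field(default_factory=StubCache)
    send_to_tokenizer: StubChannel = field(default_factory=StubChannel)
    waiting_queue: List[Req] = field(default_factory=list)

    def abort_waiting_request(self, i: int) -> None:
        # Mirrors the hunk above: pop from the queue, release any in-flight
        # prefetch first, then tell the tokenizer side to clean up its state.
        req = self.waiting_queue.pop(i)
        if self.enable_hicache_storage:
            self.tree_cache.release_aborted_request(req.rid)
        self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))

sched = MiniScheduler(enable_hicache_storage=True, waiting_queue=[Req("req-1")])
sched.abort_waiting_request(0)  # -> releases prefetch state, then sends AbortReq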
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
-        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
             req_id
-        ]
+        )
         if operation.host_indices is None:
             # prefetch has not been issued due to insufficient host memory
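
The switch from subscripting to dict.pop above is the heart of the leak fix in this function: it returns early when operation.host_indices is None, and with the old look-up-now-delete-later pattern an early return could skip the del further down, stranding the entry (and its pinned host memory) in ongoing_prefetch. A throwaway demonstration of the difference, with invented data:

# pop() removes the entry on every exit path; subscripting plus a later
# `del` misses any path that returns before reaching the `del`.
ongoing_prefetch = {"req-1": (None, [], [], None)}

def terminate_old(req_id):
    entry = ongoing_prefetch[req_id]      # entry stays in the dict
    if entry[3] is None:                  # early return skips the del below
        return
    del ongoing_prefetch[req_id]

def terminate_new(req_id):
    entry = ongoing_prefetch.pop(req_id)  # removed no matter how we exit
    if entry[3] is None:
        return

terminate_old("req-1")
print(len(ongoing_prefetch))  # 1 -> the old pattern leaked the entry
terminate_new("req-1")
print(len(ongoing_prefetch))  # 0 -> pop() cleaned it up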
@@ -512,7 +512,6 @@ class HiRadixCache(RadixCache):
                 host_indices[min_completed_tokens:completed_tokens]
             )
         last_host_node.release_host()
-        del self.ongoing_prefetch[req_id]
         self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
         return True
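
A side effect of the pop() above: the del removed in this hunk is not just redundant, it would now raise, because the key is already gone by the time control reaches this point. A short illustration with throwaway data:

ongoing_prefetch = {"req-1": ("node", [1, 2, 3], "host_indices", "op")}
entry = ongoing_prefetch.pop("req-1")  # the new removal point at the top
try:
    del ongoing_prefetch["req-1"]      # the deleted line's old job
except KeyError:
    print("entry already removed by pop()")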
@@ -771,3 +770,20 @@ class HiRadixCache(RadixCache):
             if not cur_child.evicted:
                 stack.append(cur_child)
         return ret_list
+
+    def release_aborted_request(self, rid: str):
+        if rid not in self.ongoing_prefetch:
+            return
+
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
+            rid
+        )
+        if operation.host_indices is None:
+            return
+
+        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
+        if self.tp_world_size > 1:
+            torch.distributed.barrier(group=self.tp_group)
+        last_host_node.release_host()
+        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
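
To make the new method concrete, the sketch below replays the same release sequence against stub collaborators (every name prefixed Stub or Mini is invented; terminate_prefetch here simply pretends three tokens finished transferring, and the single-rank case skips the distributed barrier). Only the completed prefix of host_indices is queued for release, while prefetch_tokens_occupied is decremented by the full token count, presumably mirroring the increment made when the prefetch was registered.

class StubController:
    def __init__(self):
        self.prefetch_tokens_occupied = 8
        self.released = []

    def terminate_prefetch(self, operation):
        return 3, None  # pretend 3 tokens finished transferring

    def append_host_mem_release(self, indices):
        self.released.extend(indices)

class StubNode:
    def release_host(self):
        print("host node reference released")

class StubOp:
    def __init__(self, host_indices):
        self.host_indices = host_indices

class MiniHiRadixCache:
    def __init__(self):
        self.cache_controller = StubController()
        self.tp_world_size = 1  # single rank: the barrier branch is skipped
        host_indices = list(range(8))
        self.ongoing_prefetch = {
            "req-1": (StubNode(), list(range(8)), host_indices, StubOp(host_indices))
        }

    def release_aborted_request(self, rid: str):
        if rid not in self.ongoing_prefetch:
            return
        node, token_ids, host_indices, op = self.ongoing_prefetch.pop(rid)
        if op.host_indices is None:
            return
        completed, _ = self.cache_controller.terminate_prefetch(op)
        node.release_host()
        self.cache_controller.append_host_mem_release(host_indices[:completed])
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

cache = MiniHiRadixCache()
cache.release_aborted_request("req-1")
print(cache.cache_controller.released)                  # [0, 1, 2]: completed prefix
print(cache.cache_controller.prefetch_tokens_occupied)  # 0: full count reclaimed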