Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
cb9e0e41
Unverified
Commit
cb9e0e41
authored
Sep 02, 2025
by
huangtingwei
Committed by
GitHub
Sep 01, 2025
Browse files
[HiCacheStorage] fix abort request host memory leaks (#9874)
Co-authored-by:
Zhiqiang Xie
<
xiezhq@stanford.edu
>
parent
9db80253
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
3 deletions
+22
-3
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+3
-0
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+19
-3
No files found.
python/sglang/srt/managers/scheduler.py
View file @
cb9e0e41
...
@@ -2403,6 +2403,9 @@ class Scheduler(
...
@@ -2403,6 +2403,9 @@ class Scheduler(
# This only works for requests that have not started anything.
# This only works for requests that have not started anything.
# We still need to send something back to TokenizerManager to clean up the state.
# We still need to send something back to TokenizerManager to clean up the state.
req
=
self
.
waiting_queue
.
pop
(
i
)
req
=
self
.
waiting_queue
.
pop
(
i
)
if
self
.
enable_hicache_storage
:
# to release prefetch events associated with the request
self
.
tree_cache
.
release_aborted_request
(
req
.
rid
)
self
.
send_to_tokenizer
.
send_pyobj
(
AbortReq
(
req
.
rid
))
self
.
send_to_tokenizer
.
send_pyobj
(
AbortReq
(
req
.
rid
))
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
# For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
if
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
cb9e0e41
...
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
...
@@ -468,9 +468,9 @@ class HiRadixCache(RadixCache):
# todo: more policies for prefetch progress such as timeout
# todo: more policies for prefetch progress such as timeout
# the current policy is to prefetch with best effort and terminate when queuing is over
# the current policy is to prefetch with best effort and terminate when queuing is over
last_host_node
,
token_ids
,
host_indices
,
operation
=
self
.
ongoing_prefetch
[
last_host_node
,
token_ids
,
host_indices
,
operation
=
self
.
ongoing_prefetch
.
pop
(
req_id
req_id
]
)
if
operation
.
host_indices
is
None
:
if
operation
.
host_indices
is
None
:
# prefetch has not been issued due to insufficient host memory
# prefetch has not been issued due to insufficient host memory
...
@@ -512,7 +512,6 @@ class HiRadixCache(RadixCache):
...
@@ -512,7 +512,6 @@ class HiRadixCache(RadixCache):
host_indices
[
min_completed_tokens
:
completed_tokens
]
host_indices
[
min_completed_tokens
:
completed_tokens
]
)
)
last_host_node
.
release_host
()
last_host_node
.
release_host
()
del
self
.
ongoing_prefetch
[
req_id
]
self
.
cache_controller
.
prefetch_tokens_occupied
-=
len
(
token_ids
)
self
.
cache_controller
.
prefetch_tokens_occupied
-=
len
(
token_ids
)
return
True
return
True
...
@@ -771,3 +770,20 @@ class HiRadixCache(RadixCache):
...
@@ -771,3 +770,20 @@ class HiRadixCache(RadixCache):
if
not
cur_child
.
evicted
:
if
not
cur_child
.
evicted
:
stack
.
append
(
cur_child
)
stack
.
append
(
cur_child
)
return
ret_list
return
ret_list
def release_aborted_request(self, rid: str):
    """Reclaim host-memory resources held by an aborted request's ongoing prefetch.

    If *rid* has no entry in ``ongoing_prefetch`` this is a no-op. Otherwise the
    entry is always removed; the actual host-memory release only happens when the
    prefetch operation was really issued (``operation.host_indices`` is not None).

    Args:
        rid: request id of the aborted request.
    """
    # Remove the bookkeeping entry unconditionally; a missing key means
    # there is nothing to clean up for this request.
    entry = self.ongoing_prefetch.pop(rid, None)
    if entry is None:
        return
    last_host_node, token_ids, host_indices, operation = entry

    # Prefetch was never issued (e.g. insufficient host memory), so no host
    # buffers or counters were ever charged — nothing further to release.
    if operation.host_indices is None:
        return

    controller = self.cache_controller
    completed_tokens, _ = controller.terminate_prefetch(operation)

    # Keep TP ranks in lockstep before touching shared host state.
    if self.tp_world_size > 1:
        torch.distributed.barrier(group=self.tp_group)

    last_host_node.release_host()
    # Only the prefix that actually completed was written; release just that.
    controller.append_host_mem_release(host_indices[:completed_tokens])
    controller.prefetch_tokens_occupied -= len(token_ids)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment