Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
14cbe42f
Unverified
Commit
14cbe42f
authored
Oct 29, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 29, 2025
Browse files
Refactor abortion in event loop (#12312)
parent
685c0645
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
26 deletions
+17
-26
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+9
-11
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+3
-10
python/sglang/srt/managers/scheduler_output_processor_mixin.py
...n/sglang/srt/managers/scheduler_output_processor_mixin.py
+1
-1
python/sglang/srt/managers/session_controller.py
python/sglang/srt/managers/session_controller.py
+4
-4
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
14cbe42f
...
...
@@ -505,16 +505,15 @@ class Req:
# Check finish
self
.
tokenizer
=
None
self
.
finished_reason
=
None
self
.
finished_reason
:
Optional
[
BaseFinishReason
]
=
None
# finished position (in output_ids), used when checking stop conditions with speculative decoding
self
.
finished_len
=
None
# Whether this request has finished output
self
.
finished_output
=
None
# If we want to abort the request in the middle of the event loop, set this to true
# If we want to abort the request in the middle of the event loop,
# set to_finish instead of directly setting finished_reason.
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond
self
.
to_abort
=
False
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
self
.
to_abort_message
:
str
=
None
self
.
to_finish
:
Optional
[
BaseFinishReason
]
=
None
self
.
stream
=
stream
self
.
eos_token_ids
=
eos_token_ids
self
.
vocab_size
=
vocab_size
...
...
@@ -866,10 +865,9 @@ class Req:
if
self
.
finished
():
return
if
self
.
to_abort
:
self
.
finished_reason
=
FINISH_ABORT
(
message
=
self
.
to_abort_message
,
)
if
self
.
to_finish
:
self
.
finished_reason
=
self
.
to_finish
self
.
to_finish
=
None
return
if
len
(
self
.
output_ids
)
>=
self
.
sampling_params
.
max_new_tokens
:
...
...
@@ -945,7 +943,7 @@ class Req:
self
.
grammar
=
None
self
.
origin_input_ids
=
[
0
]
# set it to one token to skip the long prefill
self
.
return_logprob
=
False
self
.
finish
ed_reason
=
FINISH_ABORT
(
self
.
to_
finish
=
FINISH_ABORT
(
error_msg
,
HTTPStatus
.
BAD_REQUEST
,
"BadRequestError"
)
...
...
@@ -1509,7 +1507,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
# avoid zero division
new_estimate_ratio
=
min
(
1.0
,
new_estimate_ratio
)
return
retracted_reqs
,
new_estimate_ratio
,
[]
return
retracted_reqs
,
new_estimate_ratio
def
release_req
(
self
,
idx
:
int
,
remaing_req_count
:
int
,
server_args
:
ServerArgs
):
req
=
self
.
reqs
[
idx
]
...
...
python/sglang/srt/managers/scheduler.py
View file @
14cbe42f
...
...
@@ -1817,20 +1817,13 @@ class Scheduler(
TEST_RETRACT
and
self
.
forward_ct
%
TEST_RETRACT_INTERVAL
==
0
):
old_ratio
=
self
.
new_token_ratio
retracted_reqs
,
new_token_ratio
,
reqs_to_abort
=
batch
.
retract_decode
(
self
.
server_args
)
retracted_reqs
,
new_token_ratio
=
batch
.
retract_decode
(
self
.
server_args
)
self
.
num_retracted_reqs
=
len
(
retracted_reqs
)
self
.
new_token_ratio
=
new_token_ratio
for
req
in
reqs_to_abort
:
self
.
send_to_tokenizer
.
send_output
(
AbortReq
(
abort_reason
=
req
.
to_abort_message
,
rid
=
req
.
rid
),
req
)
logger
.
info
(
"KV cache pool is full. Retract requests. "
f
"#retracted_reqs:
{
len
(
retracted_reqs
)
}
, "
f
"#aborted_retracted_reqs:
{
len
(
reqs_to_abort
)
}
, "
f
"#new_token_ratio:
{
old_ratio
:.
4
f
}
->
{
new_token_ratio
:.
4
f
}
"
)
...
...
@@ -2534,11 +2527,11 @@ class Scheduler(
if
not
req
.
finished
()
and
(
recv_req
.
abort_all
or
req
.
rid
.
startswith
(
recv_req
.
rid
)
):
# Abort method 3: set `to_
abort=True
`
# Abort method 3: set `to_
finish
`
# The request will still run one decode forward pass.
# Then we reuse all existing code to clean up the KV cache allocation.
logger
.
debug
(
f
"Abort running request.
{
req
.
rid
=
}
"
)
req
.
to_
abort
=
True
req
.
to_
finish
=
FINISH_ABORT
()
def
_pause_engine
(
self
)
->
Tuple
[
List
[
Req
],
int
]:
raise
NotImplementedError
()
...
...
python/sglang/srt/managers/scheduler_output_processor_mixin.py
View file @
14cbe42f
...
...
@@ -789,7 +789,7 @@ class SchedulerOutputProcessorMixin:
continue
# Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
if
self
.
model_config
.
is_multimodal_gen
and
req
.
to_
abort
:
if
self
.
model_config
.
is_multimodal_gen
and
req
.
to_
finish
:
continue
if
req
.
finished
():
...
...
python/sglang/srt/managers/session_controller.py
View file @
14cbe42f
...
...
@@ -15,11 +15,11 @@ import uuid
from
typing
import
Dict
,
Optional
from
sglang.srt.managers.io_struct
import
TokenizedGenerateReqInput
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
Req
class
SessionReqNode
:
def
__init__
(
self
,
req
,
parent
=
None
,
childs
=
None
):
def
__init__
(
self
,
req
:
Req
,
parent
=
None
,
childs
=
None
):
self
.
req
=
req
self
.
parent
=
parent
if
parent
is
not
None
:
...
...
@@ -36,12 +36,12 @@ class SessionReqNode:
req_node
.
clear
(
req_dict
)
if
self
.
req
.
finished_reason
is
None
:
self
.
req
.
to_
abort
=
True
self
.
req
.
to_
finish
=
FINISH_ABORT
()
del
req_dict
[
self
.
req
.
rid
]
def
abort
(
self
):
if
self
.
req
.
finished_reason
is
None
:
self
.
req
.
to_
abort
=
True
self
.
req
.
to_
finish
=
FINISH_ABORT
()
def
__str__
(
self
):
return
self
.
_str_helper
(
self
.
req
.
rid
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment