sglang commit d658f049

[overlap-spec] fix stop condition and trimming (#11819)

Unverified commit, authored Oct 19, 2025 by Liangsheng Yin and committed via GitHub on Oct 19, 2025.
Parent: 57e25de7

Showing 3 changed files with 97 additions and 50 deletions (+97, -50).
Changed files:
  python/sglang/srt/managers/detokenizer_manager.py                (+1, -0)
  python/sglang/srt/managers/schedule_batch.py                     (+84, -46)
  python/sglang/srt/managers/scheduler_output_processor_mixin.py   (+12, -4)
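The page carries no commit description beyond the title, so as a rough sketch of the behavior being fixed (hypothetical names and token ids, not the project's code): with overlapped speculative decoding several draft tokens can be accepted in a single step, so a stop token may land in the middle of the newly accepted batch, and the request must remember where it finished and trim everything after that position before detokenizing or streaming.

# Illustrative only: keep tokens up to and including the first stop token
# found in the newly accepted suffix, mirroring what finished_len /
# output_ids_through_stop do in this commit.
ASSUMED_EOS_ID = 2  # hypothetical stop-token id for this sketch

def trim_through_stop(output_ids, new_accepted_len):
    start = len(output_ids) - new_accepted_len
    for i, token_id in enumerate(output_ids[start:]):
        if token_id == ASSUMED_EOS_ID:
            return output_ids[: start + i + 1]
    return output_ids

# Three tokens accepted in one speculative step; the stop token is the second of them.
assert trim_through_stop([11, 12, 13, 2, 99], new_accepted_len=3) == [11, 12, 13, 2]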
python/sglang/srt/managers/detokenizer_manager.py

@@ -142,6 +142,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
                 return output
             assert len(output) > 0
+            # NOTE: We can always assume the last token is the matched stop token
             return output[:-1]
         return output
python/sglang/srt/managers/schedule_batch.py

@@ -486,6 +486,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # finished position (in output_ids), used when checking stop conditions with speculative decoding
+        self.finished_len = None
         # Whether this request has finished output
         self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
@@ -651,6 +653,13 @@ class Req:
         spec_alg = get_global_server_args().speculative_algorithm
         return self.sampling_params.max_new_tokens == 0 and spec_alg is None

+    @property
+    def output_ids_through_stop(self) -> List[int]:
+        """Get the output ids through the stop condition. Stop position is included."""
+        if self.finished_len is not None:
+            return self.output_ids[: self.finished_len]
+        return self.output_ids
+
     def add_latency(self, stage: RequestStage):
         if self.metrics_collector is None:
             return
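As a tiny standalone illustration of the slicing semantics above (a hypothetical stand-in class, not the real Req): finished_len counts tokens through the stop position, so the property keeps the stop token and drops anything accepted after it.

class MiniReq:
    """Hypothetical stand-in for Req, only to show the slice semantics."""

    def __init__(self, output_ids, finished_len=None):
        self.output_ids = output_ids
        self.finished_len = finished_len

    @property
    def output_ids_through_stop(self):
        if self.finished_len is not None:
            return self.output_ids[: self.finished_len]
        return self.output_ids

req = MiniReq([11, 12, 2, 99], finished_len=3)
print(req.output_ids_through_stop)  # [11, 12, 2]: stop token kept, trailing draft token dropped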
@@ -702,18 +711,20 @@ class Req:
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None

+        output_ids = self.output_ids_through_stop
         if first_iter:
             self.read_offset = len(self.origin_input_ids_unpadded)
             self.surr_offset = max(
                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
             )
             self.surr_and_decode_ids = (
-                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+                self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
             )
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.cur_decode_ids_len = len(output_ids)
         else:
-            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
+            self.cur_decode_ids_len = len(output_ids)

         return self.surr_and_decode_ids, self.read_offset - self.surr_offset
@@ -760,55 +771,31 @@ class Req:
         return False

-    def check_finished(self):
-        if self.finished():
-            return
-
-        if self.to_abort:
-            self.finished_reason = FINISH_ABORT(
-                message=self.to_abort_message,
-            )
-            return
-
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished_reason = FINISH_LENGTH(
-                length=self.sampling_params.max_new_tokens
-            )
-            return
-
-        if self.grammar is not None:
-            if self.grammar.is_terminated():
-                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
-                return
-
-        last_token_id = self.output_ids[-1]
-
-        if not self.sampling_params.ignore_eos:
-            matched_eos = False
-
-            # Check stop token ids
+    def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
+        if self.sampling_params.ignore_eos:
+            return False
+
+        matched_eos = False
+        # Check stop token ids
+        for i, token_id in enumerate(new_accepted_tokens):
             if self.sampling_params.stop_token_ids:
-                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+                matched_eos |= token_id in self.sampling_params.stop_token_ids
             if self.eos_token_ids:
-                matched_eos |= last_token_id in self.eos_token_ids
+                matched_eos |= token_id in self.eos_token_ids
             if self.tokenizer is not None:
-                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                matched_eos |= token_id == self.tokenizer.eos_token_id
                 if self.tokenizer.additional_stop_token_ids:
-                    matched_eos |= (
-                        last_token_id in self.tokenizer.additional_stop_token_ids
-                    )
+                    matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
             if matched_eos:
-                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-                return
-
-        if last_token_id > self.vocab_size or last_token_id < 0:
-            if self.sampling_params.stop_token_ids:
-                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
-            if self.eos_token_ids:
-                self.output_ids[-1] = next(iter(self.eos_token_ids))
-            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
-            return
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
+                matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
+                self.finished_len = matched_pos + 1
+                return True
+        return False

     def _check_str_based_finish(self):
         if (
             len(self.sampling_params.stop_strs) > 0
             or len(self.sampling_params.stop_regex_strs) > 0
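A worked example of the index arithmetic in _check_token_based_finish (illustrative values only, not taken from the project): with five tokens in output_ids and the last three accepted this step, a stop token at loop index i = 1 sits at absolute position 5 - 3 + 1 = 3, so finished_len becomes 4 and the slice ends exactly at the stop token.

output_ids = [11, 12, 13, 2, 99]       # 5 tokens total; 2 plays the role of a stop token
new_accepted_tokens = output_ids[-3:]  # the 3 tokens accepted in this step
i = 1                                  # loop index of the stop token within the suffix

matched_pos = len(output_ids) - len(new_accepted_tokens) + i  # 5 - 3 + 1 = 3
finished_len = matched_pos + 1                                # 4
assert output_ids[:finished_len] == [11, 12, 13, 2]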
@@ -820,7 +807,7 @@ class Req:
             for stop_str in self.sampling_params.stop_strs:
                 if stop_str in tail_str or stop_str in self.decoded_text:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
-                    return
+                    return True

         # Check stop regex
         if len(self.sampling_params.stop_regex_strs) > 0:
@@ -829,7 +816,58 @@ class Req:
                 self.finished_reason = FINISHED_MATCHED_REGEX(matched=stop_regex_str)
-                return
+                return True
+        return False
+
+    def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
+        for i, token_id in enumerate(new_accepted_tokens):
+            if token_id > self.vocab_size or token_id < 0:
+                offset = len(self.output_ids) - len(new_accepted_tokens) + i
+                if self.sampling_params.stop_token_ids:
+                    self.output_ids[offset] = next(
+                        iter(self.sampling_params.stop_token_ids)
+                    )
+                if self.eos_token_ids:
+                    self.output_ids[offset] = next(iter(self.eos_token_ids))
+                self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+                self.finished_len = offset + 1
+                return True
+        return False
+
+    def check_finished(self, new_accepted_len: int = 1):
+        if self.finished():
+            return
+
+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT(
+                message=self.to_abort_message,
+            )
+            return
+
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(
+                length=self.sampling_params.max_new_tokens
+            )
+            self.finished_len = self.sampling_params.max_new_tokens
+            return
+
+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
+        new_accepted_tokens = self.output_ids[-new_accepted_len:]
+
+        if self._check_token_based_finish(new_accepted_tokens):
+            return
+
+        if self._check_vocab_boundary_finish(new_accepted_tokens):
+            return
+
+        if self._check_str_based_finish():
+            return

     def reset_for_retract(self):
         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -286,13 +286,16 @@ class SchedulerOutputProcessorMixin:
                     self.token_to_kv_pool_allocator.free(indices_to_free)
                 continue

+            new_accepted_len = 1
             if batch.spec_algorithm.is_none():
                 req.output_ids.append(next_token_id)
             elif batch.is_v2_eagle:
                 # Only v2 eagle's output_ids are updated here.
                 req.output_ids.extend(next_token_id)
+                new_accepted_len = len(next_token_id)

-            req.check_finished()
+            req.check_finished(new_accepted_len)
             if req.finished():
                 if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
                     # FIXME(lsyin): fix the messy logic here
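A sketch of the call-site difference above (stub class and values are hypothetical, not the scheduler's real types): in plain decoding next_token_id is a single id and new_accepted_len stays 1; on the EAGLE v2 path next_token_id is a list of accepted ids, and new_accepted_len tells check_finished how many trailing tokens to scan.

from typing import List, Union

class StubReq:
    # Hypothetical stand-in for Req, only to show how new_accepted_len is derived.
    def __init__(self):
        self.output_ids: List[int] = []

    def check_finished(self, new_accepted_len: int = 1):
        print(f"scan the last {new_accepted_len} token(s): "
              f"{self.output_ids[-new_accepted_len:]}")

def process_step(req: StubReq, next_token_id: Union[int, List[int]], spec_is_none: bool):
    new_accepted_len = 1
    if spec_is_none:
        req.output_ids.append(next_token_id)      # plain decode: one token per step
    else:
        req.output_ids.extend(next_token_id)      # speculative: several accepted tokens
        new_accepted_len = len(next_token_id)
    req.check_finished(new_accepted_len)

req = StubReq()
process_step(req, 11, spec_is_none=True)             # scans [11]
process_step(req, [12, 13, 2], spec_is_none=False)   # scans [12, 13, 2]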
@@ -734,6 +737,8 @@ class SchedulerOutputProcessorMixin:
                     # because of the one additional delayed token. This "continue" prevented the dummy output.
                     continue
                 req.finished_output = True
+                if req.finished_len is None:
+                    req.finished_len = len(req.output_ids)
                 should_output = True
             else:
                 if req.stream:
@@ -776,17 +781,20 @@ class SchedulerOutputProcessorMixin:
             else:
                 decode_ids_list.append(decode_ids[req.send_decode_id_offset :])

+            # Exclude the tokens after stop condition
+            output_ids_ = req.output_ids_through_stop
+
             req.send_decode_id_offset = len(decode_ids)
             read_offsets.append(read_offset)
-            output_ids.append(req.output_ids[send_token_offset:])
-            req.send_token_offset = len(req.output_ids)
+            output_ids.append(output_ids_[send_token_offset:])
+            req.send_token_offset = len(output_ids_)
             skip_special_tokens.append(req.sampling_params.skip_special_tokens)
             spaces_between_special_tokens.append(
                 req.sampling_params.spaces_between_special_tokens
             )
             no_stop_trim.append(req.sampling_params.no_stop_trim)
             prompt_tokens.append(len(req.origin_input_ids))
-            completion_tokens.append(len(req.output_ids))
+            completion_tokens.append(len(output_ids_))
             cached_tokens.append(req.cached_tokens)

             if not self.spec_algorithm.is_none():
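The streaming path now slices from the trimmed list instead of raw output_ids; a toy example of why send_token_offset must also be measured against the trimmed list (hypothetical values, not the scheduler's code):

output_ids_ = [11, 12, 13, 2]        # output_ids_through_stop (stop token included)
send_token_offset = 2                # tokens already sent in earlier iterations

to_send = output_ids_[send_token_offset:]  # [13, 2]; trimmed tokens can never be sent
send_token_offset = len(output_ids_)       # 4, so the next iteration sends nothing extra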