sglang / Commits / de167cf5

Fix request abortion (#6184)

Unverified commit de167cf5, authored May 10, 2025 by Lianmin Zheng, committed by GitHub on May 10, 2025.
Parent: 4319978c

Showing 10 changed files with 148 additions and 84 deletions.
- .github/workflows/pr-test.yml (+2 −2)
- python/sglang/srt/entrypoints/http_server.py (+11 −0)
- python/sglang/srt/managers/schedule_batch.py (+31 −14)
- python/sglang/srt/managers/scheduler.py (+27 −40)
- python/sglang/srt/managers/scheduler_output_processor_mixin.py (+22 −13)
- python/sglang/srt/managers/tokenizer_manager.py (+38 −9)
- python/sglang/srt/sampling/sampling_params.py (+2 −0)
- python/sglang/test/send_one.py (+13 −3)
- test/srt/test_bench_serving.py (+1 −1)
- test/srt/test_flashmla.py (+1 −2)
.github/workflows/pr-test.yml

```diff
@@ -56,7 +56,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        part: [0, 1, 2, 3, 4, 5, 6, 7]
+        part: [0, 1, 2, 3, 4, 5, 6, 7, 8]
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
@@ -69,7 +69,7 @@ jobs:
       timeout-minutes: 30
       run: |
         cd test/srt
-        python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 8
+        python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9

   unit-test-backend-2-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
```
python/sglang/srt/entrypoints/http_server.py

```diff
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.utils import (
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -539,6 +540,16 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
     return Response(status_code=200)


+@app.post("/abort_request")
+async def abort_request(obj: AbortReq, request: Request):
+    """Abort a request."""
+    try:
+        _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+        return Response(status_code=200)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.post("/parse_function_call")
 async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
     """
```
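For context, a client-side sketch of hitting the new endpoint. This is illustrative only: it assumes a local server at `http://localhost:30000` and assumes a client-chosen `rid` can be supplied in the `/generate` payload (if not, take the id from the response's `meta_info` instead).

```python
# Hypothetical client sketch for the new /abort_request endpoint.
# Assumptions: a local sglang server at http://localhost:30000, and that
# a client-chosen "rid" can be passed in the /generate payload.
import threading
import time

import requests

BASE = "http://localhost:30000"

def fire():
    # A deliberately long generation; this call returns once the request
    # finishes or is aborted.
    requests.post(
        f"{BASE}/generate",
        json={
            "text": "Write a very long story.",
            "rid": "demo-rid-1",  # assumed client-chosen request id
            "sampling_params": {"max_new_tokens": 4096},
        },
    )

t = threading.Thread(target=fire)
t.start()
time.sleep(1.0)

# Cancel the (possibly queued or running) request via the new endpoint.
requests.post(f"{BASE}/abort_request", json={"rid": "demo-rid-1"})
t.join()
```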
python/sglang/srt/managers/schedule_batch.py

```diff
 from __future__ import annotations

-import hashlib
-from enum import Enum, auto
-
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """

 import copy
 import dataclasses
+import hashlib
 import logging
 import threading
+from enum import Enum, auto
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -134,9 +135,9 @@ class FINISH_LENGTH(BaseFinishReason):
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error", status_code=None, err_type=None):
+    def __init__(self, message=None, status_code=None, err_type=None):
         super().__init__(is_error=True)
-        self.message = message
+        self.message = message or "Aborted"
         self.status_code = status_code
         self.err_type = err_type
@@ -441,11 +442,13 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # Whether this request has finished output
+        self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
         # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = "Unknown error"
+        self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids
@@ -546,8 +549,6 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None
-        # used for warmup because we don't have a pair yet when init
-        self.skip_kv_transfer: bool = False
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
@@ -555,15 +556,15 @@ class Req:
         # start_send_idx = len(req.fill_ids)
         self.start_send_idx: int = 0
-        self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
         # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1
+        self.metadata_buffer_index: int = -1
+        # The first output_id transferred from prefill instance.
+        self.transferred_output_id: Optional[int] = None

     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -697,13 +698,29 @@ class Req:
         self.req_pool_idx = None
         self.already_computed = 0

+    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+        del self.kv_cache_cpu
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+            f"{self.grammar=}, "
+            f"{self.sampling_params=})"
         )


+# Batch id
 bid = 0
@@ -1447,7 +1464,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             i
             for i in range(len(self.reqs))
             if not self.reqs[i].finished()
-            and not self.reqs[i] in chunked_req_to_exclude
+            and self.reqs[i] not in chunked_req_to_exclude
         ]
         if keep_indices is None or len(keep_indices) == 0:
```
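The `to_abort` / `to_abort_message` pair exists so that an abort never sets `finished_reason` mid-loop; as the comment above notes, a request whose `finished_reason` is set in the middle would be filtered out and never respond. A toy sketch of that deferred-abort pattern, with hypothetical names, to illustrate the flow:

```python
# Minimal sketch (hypothetical names) of the deferred-abort pattern used here:
# an abort only raises a flag; the event loop converts the flag into a
# FINISH_ABORT at a safe point, so the request still produces a final output.

class FinishAbort:
    def __init__(self, message=None):
        self.message = message or "Aborted"  # mirrors FINISH_ABORT's default

class ToyReq:
    def __init__(self, rid):
        self.rid = rid
        self.to_abort = False
        self.to_abort_message = None
        self.finished_reason = None

def abort(req, message):
    # Called from anywhere, possibly mid-iteration: only set the flag.
    req.to_abort = True
    req.to_abort_message = message

def check_finished(req):
    # Called once per event-loop step, where finishing is safe.
    if req.to_abort:
        req.finished_reason = FinishAbort(req.to_abort_message)

req = ToyReq("r1")
abort(req, "client disconnected")
check_finished(req)
print(req.finished_reason.message)  # -> "client disconnected"
```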
python/sglang/srt/managers/scheduler.py

```diff
@@ -20,7 +20,6 @@ import signal
 import sys
 import threading
 import time
-import warnings
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
@@ -121,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -135,6 +130,7 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
@@ -907,19 +903,6 @@ class Scheduler(
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

-        # Handle custom logit processor passed to the request
-        custom_logit_processor = recv_req.custom_logit_processor
-        if (
-            not self.server_args.enable_custom_logit_processor
-            and custom_logit_processor is not None
-        ):
-            logger.warning(
-                "The SGLang server is not configured to enable custom logit processor."
-                "The custom logit processor passed in will be ignored."
-                "Please set --enable-custom-logits-processor to enable this feature."
-            )
-            custom_logit_processor = None
-
         if recv_req.bootstrap_port is None:
             # Use default bootstrap port
             recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
@@ -935,7 +918,7 @@ class Scheduler(
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
-            custom_logit_processor=custom_logit_processor,
+            custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
@@ -1246,8 +1229,6 @@ class Scheduler(
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
@@ -1256,8 +1237,6 @@ class Scheduler(
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if (
@@ -1774,13 +1753,13 @@ class Scheduler(
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
                     self.watchdog_last_time = current
             time.sleep(self.watchdog_timeout // 2)

-        # Print batch size and memory pool info to check whether there are de-sync issues.
-        logger.error(
-            f"{self.cur_batch.batch_size()=}, "
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            logger.error(
+                f"{self.cur_batch.batch_size()=}, "
@@ -1788,10 +1767,13 @@ class Scheduler(
-            f"{self.token_to_kv_pool_allocator.available_size()=}, "
-            f"{self.tree_cache.evictable_size()=}, "
-        )
+                f"{self.token_to_kv_pool_allocator.available_size()=}, "
+                f"{self.tree_cache.evictable_size()=}, "
+            )

-        # Wait for some time so that the parent process can print the error.
         pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
         print(file=sys.stderr, flush=True)
         print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
@@ -1923,25 +1905,30 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
+        # TODO(lmzheng): abort the requests in the grammar queue.
+
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
             if req.rid.startswith(recv_req.rid):
                 to_del.append(i)
-                break

         # Sort in reverse order to avoid index issues when deleting
-        for i in sorted(to_del, reverse=True):
+        for i in reversed(to_del):
             req = self.waiting_queue.pop(i)
+            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             logger.debug(f"Abort queued request. {req.rid=}")
-            return

         # Delete requests in the running batch
-        for req in self.running_batch.reqs:
+        if self.cur_batch is self.running_batch or self.cur_batch is None:
+            reqs = self.running_batch.reqs
+        else:
+            reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+        for req in reqs:
             if req.rid.startswith(recv_req.rid) and not req.finished():
                 logger.debug(f"Abort running request. {req.rid=}")
                 req.to_abort = True
-                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
```
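The switch from `sorted(to_del, reverse=True)` to `reversed(to_del)` works because the indices are collected in ascending order, so reversing them is equivalent and avoids a re-sort; popping from the back keeps the remaining indices valid. A standalone illustration:

```python
# Why delete by descending index: popping an earlier element shifts every
# later element left, so ascending-order deletion would use stale indices.
waiting_queue = ["r0", "r1", "r2", "r3", "r4"]
to_del = [1, 3]  # collected in ascending order, as in abort_request

for i in reversed(to_del):  # 3 first, then 1: earlier indices stay valid
    waiting_queue.pop(i)

print(waiting_queue)  # -> ['r0', 'r2', 'r4']
```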
python/sglang/srt/managers/scheduler_output_processor_mixin.py

```diff
@@ -15,6 +15,8 @@ if TYPE_CHECKING:
         Scheduler,
     )

+DEFAULT_FORCE_STREAM_INTERVAL = 50
+

 class SchedulerOutputProcessorMixin:
     """
@@ -512,19 +514,26 @@ class SchedulerOutputProcessorMixin:
             if self.model_config.is_multimodal_gen and req.to_abort:
                 continue

-            if (
-                req.finished()
-                # If stream, follow the given stream_interval
-                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                # always increase one-by-one.
-                or (
-                    not req.stream
-                    and len(req.output_ids) % 50 == 0
-                    and not self.model_config.is_multimodal_gen
-                )
-            ):
+            if req.finished():
+                if req.finished_output:
+                    # With the overlap schedule, a request will try to output twice and hit this line twice
+                    # because of the one additional delayed token. This "continue" prevented the dummy output.
+                    continue
+                req.finished_output = True
+                should_output = True
+            else:
+                if req.stream:
+                    stream_interval = (
+                        req.sampling_params.stream_interval or self.stream_interval
+                    )
+                    should_output = len(req.output_ids) % stream_interval == 0
+                else:
+                    should_output = (
+                        len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
+                        and not self.model_config.is_multimodal_gen
+                    )
+
+            if should_output:
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
```
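The refactor makes the output cadence explicit: finished requests emit exactly once (guarded by the new `finished_output` flag), streaming requests honor a per-request `stream_interval` with the server default as fallback, and non-streaming requests still flush every `DEFAULT_FORCE_STREAM_INTERVAL` tokens to keep incremental detokenization cheap. A condensed sketch of that decision, with the request fields stubbed out as a dict:

```python
# Condensed sketch of the new should_output decision (stubbed fields;
# DEFAULT_FORCE_STREAM_INTERVAL mirrors the module-level constant).
DEFAULT_FORCE_STREAM_INTERVAL = 50

def should_output(req, server_stream_interval, is_multimodal_gen=False):
    if req["finished"]:
        if req["finished_output"]:
            # The overlap schedule can visit a finished request twice; emit once.
            return False
        req["finished_output"] = True
        return True
    n = len(req["output_ids"])
    if req["stream"]:
        interval = req["stream_interval"] or server_stream_interval
        return n % interval == 0
    return n % DEFAULT_FORCE_STREAM_INTERVAL == 0 and not is_multimodal_gen

req = {"finished": False, "finished_output": False, "stream": True,
       "stream_interval": 4, "output_ids": list(range(8))}
print(should_output(req, server_stream_interval=16))  # True: 8 % 4 == 0
```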
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -288,6 +288,7 @@ class TokenizerManager:
                 ),
                 self._handle_batch_output,
             ),
+            (AbortReq, self._handle_abort_req),
             (OpenSessionReqOutput, self._handle_open_session_req_output),
             (
                 UpdateWeightFromDiskReqOutput,
@@ -341,13 +342,14 @@ class TokenizerManager:
             ]
         )

+        # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        # for disaggregtion, start kv boostrap server on prefill
+        # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -482,6 +484,14 @@ class TokenizerManager:
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )

+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logits-processor` to enable this feature."
+            )
+
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
@@ -570,9 +580,9 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        self.send_to_scheduler.send_pyobj(tokenized_obj)

     async def _wait_one_response(
         self,
@@ -587,10 +597,11 @@ class TokenizerManager:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                     )
                 continue
@@ -605,7 +616,6 @@ class TokenizerManager:
             else:
                 msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                 logger.info(msg)
-            del self.rid_to_state[obj.rid]

             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -625,10 +635,11 @@ class TokenizerManager:
                 yield out
             else:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                     )

     async def _handle_batch_request(
@@ -728,7 +739,6 @@ class TokenizerManager:
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
-        del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)
@@ -964,7 +974,7 @@ class TokenizerManager:
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(1)
+            await asyncio.sleep(2)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -1035,6 +1045,9 @@ class TokenizerManager:
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
+                logger.error(
+                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
+                )
                 continue

             # Build meta_info and return value
@@ -1098,6 +1111,7 @@ class TokenizerManager:
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+                del self.rid_to_state[rid]

             state.out_list.append(out_dict)
             state.event.set()
@@ -1246,6 +1260,9 @@ class TokenizerManager:
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def _handle_abort_req(self, recv_obj):
+        self.rid_to_state.pop(recv_obj.rid)
+
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
             recv_obj.session_id if recv_obj.success else None
@@ -1325,3 +1342,15 @@ class _Communicator(Generic[T]):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
+
+
+# Note: request abort handling logic
+# We should handle all of the following cases correctly.
+#
+# | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
+# | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
+# | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
+# | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
+#
```
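One subtle fix above is the reordering in `_send_one_request`: the request state is now registered in `rid_to_state` before the message is handed to the scheduler. A simplified sketch (hypothetical, synchronous stand-ins for the ZMQ plumbing) of why the order matters:

```python
# Sketch of the ordering fix in _send_one_request: register the request
# state before handing the message to the scheduler. If the send happened
# first, a fast reply (or an AbortReq echo) could arrive while rid_to_state
# still lacked the entry and would be dropped. Hypothetical, simplified types.
import asyncio

rid_to_state: dict[str, asyncio.Event] = {}

def send_one_request(rid: str, send) -> None:
    rid_to_state[rid] = asyncio.Event()  # 1) register first
    send(rid)                            # 2) then let the scheduler see it

def on_scheduler_output(rid: str) -> None:
    state = rid_to_state.get(rid)
    if state is None:
        print(f"Received output for {rid=} but the state was deleted.")
        return
    state.set()

# Simulate the scheduler replying immediately after the send.
send_one_request("r1", send=lambda rid: on_scheduler_output(rid))
print(rid_to_state["r1"].is_set())  # -> True
```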
python/sglang/srt/sampling/sampling_params.py

```diff
@@ -50,6 +50,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
+        stream_interval: Optional[int] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -75,6 +76,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
+        self.stream_interval = stream_interval

         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
```
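With `stream_interval` now a `SamplingParams` field, a client can override the server-wide interval per request. A hedged usage sketch against the HTTP API (assuming a local server; the payload shape follows the `sampling_params` dict accepted by `/generate`):

```python
import requests

# Ask for output after every decoded token instead of the server default.
# Assumes a local sglang server; stream_interval is the field added here.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Count from 1 to 50.",
        "stream": True,
        "sampling_params": {"max_new_tokens": 64, "stream_interval": 1},
    },
    stream=True,
)
for chunk in resp.iter_lines():
    if chunk:
        print(chunk.decode())
```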
python/sglang/test/send_one.py

```diff
@@ -27,6 +27,7 @@ class BenchArgs:
         "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
     )
     image: bool = False
+    many_images: bool = False
     stream: bool = False

     @staticmethod
@@ -48,6 +49,7 @@ class BenchArgs:
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
         parser.add_argument("--image", action="store_true")
+        parser.add_argument("--many-images", action="store_true")
         parser.add_argument("--stream", action="store_true")

     @classmethod
@@ -62,6 +64,17 @@ def send_one_prompt(args):
             "Human: Describe this image in a very short sentence.\n\nAssistant:"
         )
         image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
+    elif args.many_images:
+        args.prompt = (
+            "Human: I have one reference image and many images."
+            "Describe their relationship in a very short sentence.\n\nAssistant:"
+        )
+        image_data = [
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+        ]
     else:
         image_data = None
@@ -74,9 +87,6 @@ def send_one_prompt(args):
             "Write in a format of json.\nAssistant:"
         )
         json_schema = "$$ANY$$"
-        json_schema = (
-            '{"type": "object", "properties": {"population": {"type": "integer"}}}'
-        )
     else:
         json_schema = None
```
...
test/srt/test_bench_serving.py
View file @
de167cf5
...
@@ -190,7 +190,7 @@ class TestBenchServing(CustomTestCase):
...
@@ -190,7 +190,7 @@ class TestBenchServing(CustomTestCase):
f
"### test_vlm_online_latency
\n
"
f
"### test_vlm_online_latency
\n
"
f
'median_e2e_latency_ms:
{
res
[
"median_e2e_latency_ms"
]:.
2
f
}
ms
\n
'
f
'median_e2e_latency_ms:
{
res
[
"median_e2e_latency_ms"
]:.
2
f
}
ms
\n
'
)
)
self
.
assertLess
(
res
[
"median_e2e_latency_ms"
],
16
0
00
)
self
.
assertLess
(
res
[
"median_e2e_latency_ms"
],
16
5
00
)
if
os
.
getenv
(
"SGLANG_AMD_CI"
)
==
"1"
:
if
os
.
getenv
(
"SGLANG_AMD_CI"
)
==
"1"
:
self
.
assertLess
(
res
[
"median_ttft_ms"
],
150
)
self
.
assertLess
(
res
[
"median_ttft_ms"
],
150
)
# TODO: not set yet, need AMD machine
# TODO: not set yet, need AMD machine
...
...
test/srt/test_flashmla.py

```diff
@@ -3,7 +3,6 @@ Usage:
 python3 test/srt/test_flashmla.py
 """

-import os
 import unittest
 from types import SimpleNamespace
@@ -61,7 +60,7 @@ class TestFlashMLAAttnBackend(unittest.TestCase):
         metrics = run_eval_few_shot_gsm8k(args)
         print(metrics)

-        self.assertGreater(metrics["accuracy"], 0.62)
+        self.assertGreater(metrics["accuracy"], 0.60)


 class TestFlashMLAAttnLatency(unittest.TestCase):
```