Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
747dd450
Unverified
Commit
747dd450
authored
Jul 28, 2025
by
harrisonlimh
Committed by
GitHub
Jul 28, 2025
Browse files
feat: throttle requests at scheduler based on --max_queued_requests (#7565)
parent
b5821592
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
218 additions
and
6 deletions
+218
-6
python/sglang/srt/entrypoints/http_server.py
python/sglang/srt/entrypoints/http_server.py
+13
-1
python/sglang/srt/entrypoints/openai/serving_base.py
python/sglang/srt/entrypoints/openai/serving_base.py
+5
-2
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+2
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+19
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+25
-3
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+5
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+8
-0
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+53
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_request_queue_validation.py
test/srt/test_request_queue_validation.py
+87
-0
No files found.
python/sglang/srt/entrypoints/http_server.py
View file @
747dd450
...
...
@@ -38,7 +38,7 @@ import orjson
import
requests
import
uvicorn
import
uvloop
from
fastapi
import
Depends
,
FastAPI
,
Request
,
UploadFile
from
fastapi
import
Depends
,
FastAPI
,
HTTPException
,
Request
,
UploadFile
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
...
...
@@ -174,6 +174,18 @@ app.add_middleware(
)
@app.exception_handler(HTTPException)
async def validation_exception_handler(request: Request, exc: HTTPException):
    """Enrich HTTP exception with status code and other details"""
    # NOTE(review): this name is shadowed by the RequestValidationError handler
    # defined below with the same name; FastAPI registers both at decoration
    # time so both still fire, but renaming one would aid debugging.
    status = exc.status_code
    payload = ErrorResponse(
        object="error",
        message=exc.detail,
        type=str(status),
        code=status,
    ).model_dump()
    return ORJSONResponse(content=payload, status_code=status)
# Custom exception handlers to change validation error status codes
@
app
.
exception_handler
(
RequestValidationError
)
async
def
validation_exception_handler
(
request
:
Request
,
exc
:
RequestValidationError
):
...
...
python/sglang/srt/entrypoints/openai/serving_base.py
View file @
747dd450
...
...
@@ -4,7 +4,7 @@ import uuid
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Optional
,
Union
from
fastapi
import
Request
from
fastapi
import
HTTPException
,
Request
from
fastapi.responses
import
ORJSONResponse
,
StreamingResponse
from
sglang.srt.entrypoints.openai.protocol
import
ErrorResponse
,
OpenAIServingRequest
...
...
@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
return
await
self
.
_handle_non_streaming_request
(
adapted_request
,
processed_request
,
raw_request
)
except
HTTPException
as
e
:
return
self
.
create_error_response
(
message
=
e
.
detail
,
err_type
=
str
(
e
.
status_code
),
status_code
=
e
.
status_code
)
except
Exception
as
e
:
logger
.
exception
(
f
"Error in request:
{
e
}
"
)
return
self
.
create_error_response
(
...
...
python/sglang/srt/managers/io_struct.py
View file @
747dd450
...
...
@@ -911,6 +911,8 @@ class AbortReq:
rid
:
str
=
""
# Whether to abort all requests
abort_all
:
bool
=
False
# The finished reason data
finished_reason
:
Optional
[
Dict
[
str
,
Any
]]
=
None
@
dataclass
...
...
python/sglang/srt/managers/scheduler.py
View file @
747dd450
...
...
@@ -24,6 +24,7 @@ import time
from
collections
import
defaultdict
,
deque
from
concurrent
import
futures
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
pathlib
import
Path
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
...
...
@@ -370,6 +371,7 @@ class Scheduler(
self
.
max_total_num_tokens
,
self
.
max_prefill_tokens
,
self
.
max_running_requests
,
self
.
max_queued_requests
,
self
.
max_req_len
,
self
.
max_req_input_len
,
self
.
random_seed
,
...
...
@@ -1086,6 +1088,19 @@ class Scheduler(
self
.
return_health_check_ct
+=
1
continue
# If it is a work request, accept or reject the request based on the request queue size.
if
is_work_request
(
recv_req
):
if
len
(
self
.
waiting_queue
)
+
1
>
self
.
max_queued_requests
:
abort_req
=
AbortReq
(
recv_req
.
rid
,
finished_reason
=
{
"type"
:
"abort"
,
"status_code"
:
HTTPStatus
.
SERVICE_UNAVAILABLE
,
"message"
:
"The request queue is full."
,
},
)
self
.
send_to_tokenizer
.
send_pyobj
(
abort_req
)
continue
output
=
self
.
_request_dispatcher
(
recv_req
)
if
output
is
not
None
:
if
isinstance
(
output
,
RpcReqOutput
):
...
...
@@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req):
return
getattr
(
recv_req
,
"rid"
,
""
).
startswith
(
"HEALTH_CHECK"
)
def is_work_request(recv_req):
    """Return True when *recv_req* is a tokenized generate/embedding request,
    i.e. real inference work that counts against the request queue limit."""
    work_request_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_request_types)
def
_export_static_state
(
model
):
return
dict
(
buffers
=
[
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
747dd450
...
...
@@ -766,6 +766,19 @@ class TokenizerManager:
):
raise
ValueError
(
finish_reason
[
"message"
])
if
(
finish_reason
.
get
(
"type"
)
==
"abort"
and
finish_reason
.
get
(
"status_code"
)
==
HTTPStatus
.
SERVICE_UNAVAILABLE
):
# This is an abort request initiated by scheduler.
# Delete the key to prevent resending abort request to the scheduler and
# to ensure aborted request state is cleaned up.
del
self
.
rid_to_state
[
state
.
obj
.
rid
]
raise
fastapi
.
HTTPException
(
status_code
=
finish_reason
[
"status_code"
],
detail
=
finish_reason
[
"message"
],
)
yield
out
break
...
...
@@ -1705,8 +1718,15 @@ class TokenizerManager:
def
_handle_abort_req
(
self
,
recv_obj
):
state
=
self
.
rid_to_state
[
recv_obj
.
rid
]
state
.
finished
=
True
state
.
out_list
.
append
(
{
if
recv_obj
.
finished_reason
:
out
=
{
"meta_info"
:
{
"id"
:
recv_obj
.
rid
,
"finish_reason"
:
recv_obj
.
finished_reason
,
},
}
else
:
out
=
{
"text"
:
""
,
"meta_info"
:
{
"id"
:
recv_obj
.
rid
,
...
...
@@ -1718,7 +1738,7 @@ class TokenizerManager:
"completion_tokens"
:
0
,
},
}
)
state
.
out_list
.
append
(
out
)
state
.
event
.
set
()
def
_handle_open_session_req_output
(
self
,
recv_obj
):
...
...
@@ -1910,8 +1930,10 @@ class _Communicator(Generic[T]):
#
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http | yes | validation | background task | fast api | del in _handle_abort_req |
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
# | http | yes | running | background task | fast api | del in _handle_batch_output |
# | http | no | validation | http exception | http exception | del in _handle_abort_req |
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
#
python/sglang/srt/managers/tp_worker.py
View file @
747dd450
...
...
@@ -130,6 +130,10 @@ class TpModelWorker:
self
.
model_runner
.
req_to_token_pool
.
size
,
)
assert
self
.
max_running_requests
>
0
,
"max_running_request is zero"
self
.
max_queued_requests
=
server_args
.
max_queued_requests
assert
(
self
.
max_running_requests
>
0
),
"max_queued_requests is zero. We need to be at least 1 to schedule a request."
self
.
max_req_len
=
min
(
self
.
model_config
.
context_len
-
1
,
self
.
max_total_num_tokens
-
1
,
...
...
@@ -165,6 +169,7 @@ class TpModelWorker:
self
.
max_total_num_tokens
,
self
.
max_prefill_tokens
,
self
.
max_running_requests
,
self
.
max_queued_requests
,
self
.
max_req_len
,
self
.
max_req_input_len
,
self
.
random_seed
,
...
...
python/sglang/srt/server_args.py
View file @
747dd450
...
...
@@ -19,6 +19,7 @@ import json
import
logging
import
os
import
random
import
sys
import
tempfile
from
typing
import
List
,
Literal
,
Optional
,
Union
...
...
@@ -74,6 +75,7 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static
:
Optional
[
float
]
=
None
max_running_requests
:
Optional
[
int
]
=
None
max_queued_requests
:
Optional
[
int
]
=
sys
.
maxsize
max_total_tokens
:
Optional
[
int
]
=
None
chunked_prefill_size
:
Optional
[
int
]
=
None
max_prefill_tokens
:
int
=
16384
...
...
@@ -805,6 +807,12 @@ class ServerArgs:
default
=
ServerArgs
.
max_running_requests
,
help
=
"The maximum number of running requests."
,
)
parser
.
add_argument
(
"--max-queued-requests"
,
type
=
int
,
default
=
ServerArgs
.
max_queued_requests
,
help
=
"The maximum number of queued requests. This option is ignored when using disaggregation-mode."
,
)
parser
.
add_argument
(
"--max-total-tokens"
,
type
=
int
,
...
...
python/sglang/test/test_utils.py
View file @
747dd450
...
...
@@ -19,6 +19,7 @@ from pathlib import Path
from
types
import
SimpleNamespace
from
typing
import
Awaitable
,
Callable
,
List
,
Optional
,
Tuple
import
aiohttp
import
numpy
as
np
import
requests
import
torch
...
...
@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
raise
def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
    """Sends generate request serially and returns status codes. Max concurrency is 1.

    Args:
        base_url: Base URL of the running sglang server.
        num_requests: Number of requests to send, one after another.

    Returns:
        The HTTP status code of each response, in send order.
        (Annotation fixed: ``response.status_code`` is an int, not a str.)
    """

    def generate() -> int:
        prompt = """
    System: You are a helpful assistant.
    User: What is the capital of France?
    Assistant: The capital of France is
    """
        response = requests.post(
            f"{base_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    # Deterministic decoding keeps the test reproducible.
                    "temperature": 0,
                    "max_new_tokens": 50,
                },
            },
        )
        return response.status_code

    return [generate() for _ in range(num_requests)]
async def send_concurrent_generate_requests(
    base_url: str, num_requests: int
) -> List[int]:
    """Sends generate request concurrently and returns status codes. Max concurrency is num_requests.

    Args:
        base_url: Base URL of the running sglang server.
        num_requests: Number of requests launched concurrently.

    Returns:
        The HTTP status code of each response, in task-creation order.
        (Annotation fixed: ``response.status`` is an int, not a str.)
    """

    async def async_generate() -> int:
        async with aiohttp.ClientSession() as session:
            prompt = """
    System: You are a helpful assistant.
    User: What is the capital of France?
    Assistant: The capital of France is
    """
            async with session.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        # Deterministic decoding keeps the test reproducible.
                        "temperature": 0,
                        "max_new_tokens": 50,
                    },
                },
            ) as response:
                return response.status

    tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
    return await asyncio.gather(*tasks)
class
CustomTestCase
(
unittest
.
TestCase
):
def
_callTestMethod
(
self
,
method
):
max_retry
=
int
(
...
...
test/srt/run_suite.py
View file @
747dd450
...
...
@@ -86,6 +86,7 @@ suites = {
TestFile
(
"test_radix_attention.py"
,
105
),
TestFile
(
"test_regex_constrained.py"
,
64
),
TestFile
(
"test_retract_decode.py"
,
54
),
TestFile
(
"test_request_queue_validation.py"
,
30
),
TestFile
(
"test_server_args.py"
,
1
),
TestFile
(
"test_skip_tokenizer_init.py"
,
117
),
TestFile
(
"test_srt_engine.py"
,
261
),
...
...
test/srt/test_request_queue_validation.py
0 → 100644
View file @
747dd450
import
asyncio
import
os
import
re
import
unittest
from
concurrent.futures
import
ThreadPoolExecutor
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
STDERR_FILENAME
,
STDOUT_FILENAME
,
CustomTestCase
,
popen_launch_server
,
send_concurrent_generate_requests
,
send_generate_requests
,
)
class TestMaxQueuedRequests(CustomTestCase):
    """End-to-end checks for scheduler request throttling via --max-queued-requests."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        # Fix: base_url was assigned twice in the original; assign once.
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--max-running-requests",
                # Enforce max request concurrency is 1
                "1",
                "--max-queued-requests",
                # Enforce max queued request number is 1
                "1",
            ),
            return_stdout_stderr=(cls.stdout, cls.stderr),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        cls.stdout.close()
        cls.stderr.close()
        os.remove(STDOUT_FILENAME)
        os.remove(STDERR_FILENAME)

    def test_max_queued_requests_validation_with_serial_requests(self):
        """Verify request is not throttled when the max concurrency is 1."""
        status_codes = send_generate_requests(
            self.base_url,
            num_requests=10,
        )
        for status_code in status_codes:
            # Fix: use unittest assertions instead of bare `assert`, which is
            # stripped under `python -O` and gives poor failure messages.
            self.assertEqual(status_code, 200)  # request shouldn't be throttled

    def test_max_queued_requests_validation_with_concurrent_requests(self):
        """Verify request throttling with concurrent requests."""
        status_codes = asyncio.run(
            send_concurrent_generate_requests(self.base_url, num_requests=10)
        )
        # With queue capacity 1 and concurrency 10, some requests must be
        # accepted (200) and some rejected with 503; nothing else is allowed.
        self.assertIn(200, status_codes)
        self.assertIn(503, status_codes)
        self.assertTrue(
            all(status_code in (200, 503) for status_code in status_codes)
        )

    def test_max_running_requests_and_max_queued_request_validation(self):
        """Verify running request and queued request numbers based on server logs."""
        rr_pattern = re.compile(r"#running-req:\s*(\d+)")
        qr_pattern = re.compile(r"#queue-req:\s*(\d+)")
        with open(STDERR_FILENAME) as lines:
            for line in lines:
                rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line)
                if rr_match:
                    self.assertLessEqual(int(rr_match.group(1)), 1)
                if qr_match:
                    self.assertLessEqual(int(qr_match.group(1)), 1)
# Allow running this test module directly (outside the run_suite harness).
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment