Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
747dd450
Unverified
Commit
747dd450
authored
Jul 28, 2025
by
harrisonlimh
Committed by
GitHub
Jul 28, 2025
Browse files
feat: throttle requests at scheduler based on --max_queued_requests (#7565)
parent
b5821592
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
218 additions
and
6 deletions
+218
-6
python/sglang/srt/entrypoints/http_server.py
python/sglang/srt/entrypoints/http_server.py
+13
-1
python/sglang/srt/entrypoints/openai/serving_base.py
python/sglang/srt/entrypoints/openai/serving_base.py
+5
-2
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+2
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+19
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+25
-3
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+5
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+8
-0
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+53
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_request_queue_validation.py
test/srt/test_request_queue_validation.py
+87
-0
No files found.
python/sglang/srt/entrypoints/http_server.py
View file @
747dd450
...
@@ -38,7 +38,7 @@ import orjson
...
@@ -38,7 +38,7 @@ import orjson
import
requests
import
requests
import
uvicorn
import
uvicorn
import
uvloop
import
uvloop
from
fastapi
import
Depends
,
FastAPI
,
Request
,
UploadFile
from
fastapi
import
Depends
,
FastAPI
,
HTTPException
,
Request
,
UploadFile
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
...
@@ -174,6 +174,18 @@ app.add_middleware(
...
@@ -174,6 +174,18 @@ app.add_middleware(
)
)
@app.exception_handler(HTTPException)
async def validation_exception_handler(request: Request, exc: HTTPException):
    """Enrich HTTP exception with status code and other details"""
    # Mirror the exception's status code into both the error type and the
    # response status so clients get a structured, OpenAI-style error body.
    status = exc.status_code
    error = ErrorResponse(
        object="error",
        message=exc.detail,
        type=str(status),
        code=status,
    )
    return ORJSONResponse(content=error.model_dump(), status_code=status)
# Custom exception handlers to change validation error status codes
# Custom exception handlers to change validation error status codes
@
app
.
exception_handler
(
RequestValidationError
)
@
app
.
exception_handler
(
RequestValidationError
)
async
def
validation_exception_handler
(
request
:
Request
,
exc
:
RequestValidationError
):
async
def
validation_exception_handler
(
request
:
Request
,
exc
:
RequestValidationError
):
...
...
python/sglang/srt/entrypoints/openai/serving_base.py
View file @
747dd450
...
@@ -4,7 +4,7 @@ import uuid
...
@@ -4,7 +4,7 @@ import uuid
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
from
fastapi
import
Request
from
fastapi
import
HTTPException
,
Request
from
fastapi.responses
import
ORJSONResponse
,
StreamingResponse
from
fastapi.responses
import
ORJSONResponse
,
StreamingResponse
from
sglang.srt.entrypoints.openai.protocol
import
ErrorResponse
,
OpenAIServingRequest
from
sglang.srt.entrypoints.openai.protocol
import
ErrorResponse
,
OpenAIServingRequest
...
@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
...
@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
return
await
self
.
_handle_non_streaming_request
(
return
await
self
.
_handle_non_streaming_request
(
adapted_request
,
processed_request
,
raw_request
adapted_request
,
processed_request
,
raw_request
)
)
except
HTTPException
as
e
:
return
self
.
create_error_response
(
message
=
e
.
detail
,
err_type
=
str
(
e
.
status_code
),
status_code
=
e
.
status_code
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
exception
(
f
"Error in request:
{
e
}
"
)
logger
.
exception
(
f
"Error in request:
{
e
}
"
)
return
self
.
create_error_response
(
return
self
.
create_error_response
(
...
...
python/sglang/srt/managers/io_struct.py
View file @
747dd450
...
@@ -911,6 +911,8 @@ class AbortReq:
...
@@ -911,6 +911,8 @@ class AbortReq:
rid
:
str
=
""
rid
:
str
=
""
# Whether to abort all requests
# Whether to abort all requests
abort_all
:
bool
=
False
abort_all
:
bool
=
False
# The finished reason data
finished_reason
:
Optional
[
Dict
[
str
,
Any
]]
=
None
@
dataclass
@
dataclass
...
...
python/sglang/srt/managers/scheduler.py
View file @
747dd450
...
@@ -24,6 +24,7 @@ import time
...
@@ -24,6 +24,7 @@ import time
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
concurrent
import
futures
from
concurrent
import
futures
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
pathlib
import
Path
from
pathlib
import
Path
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
...
@@ -370,6 +371,7 @@ class Scheduler(
...
@@ -370,6 +371,7 @@ class Scheduler(
self
.
max_total_num_tokens
,
self
.
max_total_num_tokens
,
self
.
max_prefill_tokens
,
self
.
max_prefill_tokens
,
self
.
max_running_requests
,
self
.
max_running_requests
,
self
.
max_queued_requests
,
self
.
max_req_len
,
self
.
max_req_len
,
self
.
max_req_input_len
,
self
.
max_req_input_len
,
self
.
random_seed
,
self
.
random_seed
,
...
@@ -1086,6 +1088,19 @@ class Scheduler(
...
@@ -1086,6 +1088,19 @@ class Scheduler(
self
.
return_health_check_ct
+=
1
self
.
return_health_check_ct
+=
1
continue
continue
# If it is a work request, accept or reject the request based on the request queue size.
if
is_work_request
(
recv_req
):
if
len
(
self
.
waiting_queue
)
+
1
>
self
.
max_queued_requests
:
abort_req
=
AbortReq
(
recv_req
.
rid
,
finished_reason
=
{
"type"
:
"abort"
,
"status_code"
:
HTTPStatus
.
SERVICE_UNAVAILABLE
,
"message"
:
"The request queue is full."
,
},
)
self
.
send_to_tokenizer
.
send_pyobj
(
abort_req
)
continue
output
=
self
.
_request_dispatcher
(
recv_req
)
output
=
self
.
_request_dispatcher
(
recv_req
)
if
output
is
not
None
:
if
output
is
not
None
:
if
isinstance
(
output
,
RpcReqOutput
):
if
isinstance
(
output
,
RpcReqOutput
):
...
@@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req):
...
@@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req):
return
getattr
(
recv_req
,
"rid"
,
""
).
startswith
(
"HEALTH_CHECK"
)
return
getattr
(
recv_req
,
"rid"
,
""
).
startswith
(
"HEALTH_CHECK"
)
def is_work_request(recv_req):
    """Return True when ``recv_req`` is a real inference request.

    Only tokenized generate/embedding inputs count as work; control
    messages (aborts, RPCs, health checks) are not throttled by the
    request-queue limit.
    """
    work_request_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_request_types)
def
_export_static_state
(
model
):
def
_export_static_state
(
model
):
return
dict
(
return
dict
(
buffers
=
[
buffers
=
[
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
747dd450
...
@@ -766,6 +766,19 @@ class TokenizerManager:
...
@@ -766,6 +766,19 @@ class TokenizerManager:
):
):
raise
ValueError
(
finish_reason
[
"message"
])
raise
ValueError
(
finish_reason
[
"message"
])
if
(
finish_reason
.
get
(
"type"
)
==
"abort"
and
finish_reason
.
get
(
"status_code"
)
==
HTTPStatus
.
SERVICE_UNAVAILABLE
):
# This is an abort request initiated by scheduler.
# Delete the key to prevent resending abort request to the scheduler and
# to ensure aborted request state is cleaned up.
del
self
.
rid_to_state
[
state
.
obj
.
rid
]
raise
fastapi
.
HTTPException
(
status_code
=
finish_reason
[
"status_code"
],
detail
=
finish_reason
[
"message"
],
)
yield
out
yield
out
break
break
...
@@ -1705,8 +1718,15 @@ class TokenizerManager:
...
@@ -1705,8 +1718,15 @@ class TokenizerManager:
def
_handle_abort_req
(
self
,
recv_obj
):
def
_handle_abort_req
(
self
,
recv_obj
):
state
=
self
.
rid_to_state
[
recv_obj
.
rid
]
state
=
self
.
rid_to_state
[
recv_obj
.
rid
]
state
.
finished
=
True
state
.
finished
=
True
state
.
out_list
.
append
(
if
recv_obj
.
finished_reason
:
{
out
=
{
"meta_info"
:
{
"id"
:
recv_obj
.
rid
,
"finish_reason"
:
recv_obj
.
finished_reason
,
},
}
else
:
out
=
{
"text"
:
""
,
"text"
:
""
,
"meta_info"
:
{
"meta_info"
:
{
"id"
:
recv_obj
.
rid
,
"id"
:
recv_obj
.
rid
,
...
@@ -1718,7 +1738,7 @@ class TokenizerManager:
...
@@ -1718,7 +1738,7 @@ class TokenizerManager:
"completion_tokens"
:
0
,
"completion_tokens"
:
0
,
},
},
}
}
)
state
.
out_list
.
append
(
out
)
state
.
event
.
set
()
state
.
event
.
set
()
def
_handle_open_session_req_output
(
self
,
recv_obj
):
def
_handle_open_session_req_output
(
self
,
recv_obj
):
...
@@ -1910,8 +1930,10 @@ class _Communicator(Generic[T]):
...
@@ -1910,8 +1930,10 @@ class _Communicator(Generic[T]):
#
#
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http | yes | validation | background task | fast api | del in _handle_abort_req |
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
# | http | yes | running | background task | fast api | del in _handle_batch_output |
# | http | yes | running | background task | fast api | del in _handle_batch_output |
# | http | no | validation | http exception | http exception | del in _handle_abort_req |
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
#
#
python/sglang/srt/managers/tp_worker.py
View file @
747dd450
...
@@ -130,6 +130,10 @@ class TpModelWorker:
...
@@ -130,6 +130,10 @@ class TpModelWorker:
self
.
model_runner
.
req_to_token_pool
.
size
,
self
.
model_runner
.
req_to_token_pool
.
size
,
)
)
assert
self
.
max_running_requests
>
0
,
"max_running_request is zero"
assert
self
.
max_running_requests
>
0
,
"max_running_request is zero"
self
.
max_queued_requests
=
server_args
.
max_queued_requests
assert
(
self
.
max_running_requests
>
0
),
"max_queued_requests is zero. We need to be at least 1 to schedule a request."
self
.
max_req_len
=
min
(
self
.
max_req_len
=
min
(
self
.
model_config
.
context_len
-
1
,
self
.
model_config
.
context_len
-
1
,
self
.
max_total_num_tokens
-
1
,
self
.
max_total_num_tokens
-
1
,
...
@@ -165,6 +169,7 @@ class TpModelWorker:
...
@@ -165,6 +169,7 @@ class TpModelWorker:
self
.
max_total_num_tokens
,
self
.
max_total_num_tokens
,
self
.
max_prefill_tokens
,
self
.
max_prefill_tokens
,
self
.
max_running_requests
,
self
.
max_running_requests
,
self
.
max_queued_requests
,
self
.
max_req_len
,
self
.
max_req_len
,
self
.
max_req_input_len
,
self
.
max_req_input_len
,
self
.
random_seed
,
self
.
random_seed
,
...
...
python/sglang/srt/server_args.py
View file @
747dd450
...
@@ -19,6 +19,7 @@ import json
...
@@ -19,6 +19,7 @@ import json
import
logging
import
logging
import
os
import
os
import
random
import
random
import
sys
import
tempfile
import
tempfile
from
typing
import
List
,
Literal
,
Optional
,
Union
from
typing
import
List
,
Literal
,
Optional
,
Union
...
@@ -74,6 +75,7 @@ class ServerArgs:
...
@@ -74,6 +75,7 @@ class ServerArgs:
# Memory and scheduling
# Memory and scheduling
mem_fraction_static
:
Optional
[
float
]
=
None
mem_fraction_static
:
Optional
[
float
]
=
None
max_running_requests
:
Optional
[
int
]
=
None
max_running_requests
:
Optional
[
int
]
=
None
max_queued_requests
:
Optional
[
int
]
=
sys
.
maxsize
max_total_tokens
:
Optional
[
int
]
=
None
max_total_tokens
:
Optional
[
int
]
=
None
chunked_prefill_size
:
Optional
[
int
]
=
None
chunked_prefill_size
:
Optional
[
int
]
=
None
max_prefill_tokens
:
int
=
16384
max_prefill_tokens
:
int
=
16384
...
@@ -805,6 +807,12 @@ class ServerArgs:
...
@@ -805,6 +807,12 @@ class ServerArgs:
default
=
ServerArgs
.
max_running_requests
,
default
=
ServerArgs
.
max_running_requests
,
help
=
"The maximum number of running requests."
,
help
=
"The maximum number of running requests."
,
)
)
parser
.
add_argument
(
"--max-queued-requests"
,
type
=
int
,
default
=
ServerArgs
.
max_queued_requests
,
help
=
"The maximum number of queued requests. This option is ignored when using disaggregation-mode."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--max-total-tokens"
,
"--max-total-tokens"
,
type
=
int
,
type
=
int
,
...
...
python/sglang/test/test_utils.py
View file @
747dd450
...
@@ -19,6 +19,7 @@ from pathlib import Path
...
@@ -19,6 +19,7 @@ from pathlib import Path
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
typing
import
Awaitable
,
Callable
,
List
,
Optional
,
Tuple
from
typing
import
Awaitable
,
Callable
,
List
,
Optional
,
Tuple
import
aiohttp
import
numpy
as
np
import
numpy
as
np
import
requests
import
requests
import
torch
import
torch
...
@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
...
@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
raise
raise
def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
    """Send generate requests serially and return their HTTP status codes.

    Max concurrency is 1: each request completes before the next is sent,
    so the server's waiting queue never holds more than one request.

    Args:
        base_url: Base URL of the running sglang HTTP server.
        num_requests: Number of requests to send, one after another.

    Returns:
        List of HTTP status codes, one per request, in send order.
    """

    def generate() -> int:
        prompt = """
    System: You are a helpful assistant.
    User: What is the capital of France?
    Assistant: The capital of France is
    """
        response = requests.post(
            f"{base_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 50,
                },
            },
        )
        return response.status_code

    return [generate() for _ in range(num_requests)]
async def send_concurrent_generate_requests(
    base_url: str, num_requests: int
) -> List[int]:
    """Send generate requests concurrently and return their HTTP status codes.

    Max concurrency is ``num_requests``: all requests are in flight at once,
    so some may be rejected (503) when the server's queue limit is exceeded.

    Args:
        base_url: Base URL of the running sglang HTTP server.
        num_requests: Number of requests to launch concurrently.

    Returns:
        List of HTTP status codes, one per request (order follows task
        creation order, per ``asyncio.gather``).
    """

    async def async_generate() -> int:
        async with aiohttp.ClientSession() as session:
            prompt = """
    System: You are a helpful assistant.
    User: What is the capital of France?
    Assistant: The capital of France is
    """
            async with session.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 50,
                    },
                },
            ) as response:
                return response.status

    tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
    return await asyncio.gather(*tasks)
class
CustomTestCase
(
unittest
.
TestCase
):
class
CustomTestCase
(
unittest
.
TestCase
):
def
_callTestMethod
(
self
,
method
):
def
_callTestMethod
(
self
,
method
):
max_retry
=
int
(
max_retry
=
int
(
...
...
test/srt/run_suite.py
View file @
747dd450
...
@@ -86,6 +86,7 @@ suites = {
...
@@ -86,6 +86,7 @@ suites = {
TestFile
(
"test_radix_attention.py"
,
105
),
TestFile
(
"test_radix_attention.py"
,
105
),
TestFile
(
"test_regex_constrained.py"
,
64
),
TestFile
(
"test_regex_constrained.py"
,
64
),
TestFile
(
"test_retract_decode.py"
,
54
),
TestFile
(
"test_retract_decode.py"
,
54
),
TestFile
(
"test_request_queue_validation.py"
,
30
),
TestFile
(
"test_server_args.py"
,
1
),
TestFile
(
"test_server_args.py"
,
1
),
TestFile
(
"test_skip_tokenizer_init.py"
,
117
),
TestFile
(
"test_skip_tokenizer_init.py"
,
117
),
TestFile
(
"test_srt_engine.py"
,
261
),
TestFile
(
"test_srt_engine.py"
,
261
),
...
...
test/srt/test_request_queue_validation.py
0 → 100644
View file @
747dd450
import
asyncio
import
os
import
re
import
unittest
from
concurrent.futures
import
ThreadPoolExecutor
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
STDERR_FILENAME
,
STDOUT_FILENAME
,
CustomTestCase
,
popen_launch_server
,
send_concurrent_generate_requests
,
send_generate_requests
,
)
class TestMaxQueuedRequests(CustomTestCase):
    """Validate scheduler request throttling via --max-queued-requests.

    Launches a server allowing at most 1 running and 1 queued request, then
    checks that serial traffic is never rejected while concurrent traffic is
    partially rejected with HTTP 503.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Capture server output so the log-based test below can inspect it.
        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--max-running-requests",
                # Enforce max request concurrency is 1
                "1",
                "--max-queued-requests",
                # Enforce max queued request number is 1
                "1",
            ),
            return_stdout_stderr=(cls.stdout, cls.stderr),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        cls.stdout.close()
        cls.stderr.close()
        os.remove(STDOUT_FILENAME)
        os.remove(STDERR_FILENAME)

    def test_max_queued_requests_validation_with_serial_requests(self):
        """Verify request is not throttled when the max concurrency is 1."""
        status_codes = send_generate_requests(
            self.base_url,
            num_requests=10,
        )
        for status_code in status_codes:
            assert status_code == 200  # request shouldn't be throttled

    def test_max_queued_requests_validation_with_concurrent_requests(self):
        """Verify request throttling with concurrent requests."""
        status_codes = asyncio.run(
            send_concurrent_generate_requests(self.base_url, num_requests=10)
        )
        # With 1 running + 1 queued slot, some of 10 concurrent requests
        # must succeed and some must be rejected — nothing else is allowed.
        assert 200 in status_codes
        assert 503 in status_codes
        assert all(status_code in [200, 503] for status_code in status_codes)

    def test_max_running_requests_and_max_queued_request_validation(self):
        """Verify running request and queued request numbers based on server logs."""
        rr_pattern = re.compile(r"#running-req:\s*(\d+)")
        qr_pattern = re.compile(r"#queue-req:\s*(\d+)")
        with open(STDERR_FILENAME) as lines:
            for line in lines:
                rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line)
                if rr_match:
                    assert int(rr_match.group(1)) <= 1
                if qr_match:
                    assert int(qr_match.group(1)) <= 1
# Allow running this test file directly (outside the run_suite.py harness).
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment