sglang / Commits / a8ccacc8

Unverified commit a8ccacc8, authored Jan 16, 2025 by Chang Su, committed by GitHub on Jan 16, 2025

[Frontend] Fix request length check and add option to disallow auto truncation in scheduler (#2876)

Parent: 0427416b
Showing 6 changed files with 154 additions and 17 deletions.
python/sglang/srt/managers/scheduler.py             +17 -15
python/sglang/srt/managers/tokenizer_manager.py     +18 -2
python/sglang/srt/managers/utils.py                 +41 -0
python/sglang/srt/server_args.py                    +6 -0
test/srt/run_suite.py                               +1 -0
test/srt/test_request_length_validation.py          +71 -0
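Before the per-file diffs, a note on the user-visible effect: with the new --allow-auto-truncate flag left at its default (off), an over-long prompt is rejected with an error instead of being silently truncated, as the new test file below exercises. A minimal sketch against the OpenAI-compatible endpoint, assuming a locally launched server; the URL, API key, and model name are placeholders:

import openai

# Placeholder endpoint and credentials for a locally launched sglang server.
client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

try:
    client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "hello " * 10000}],
        temperature=0,
    )
except openai.BadRequestError as err:
    # Without --allow-auto-truncate, the server now returns a request error
    # rather than truncating the prompt and generating anyway.
    print(err)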
python/sglang/srt/managers/scheduler.py
@@ -78,6 +78,7 @@ from sglang.srt.managers.schedule_policy import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -690,14 +691,16 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(req.origin_input_ids) - 1
 
-        # Truncate prompts that are too long
-        if len(req.origin_input_ids) > self.max_req_input_len:
-            logger.warning(
-                "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated. "
-                f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
-            )
-            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        # Validate prompts length
+        error_msg = validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
+
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
 
         req.sampling_params.max_new_tokens = min(
             (
@@ -745,13 +748,12 @@ class Scheduler:
         )
         req.tokenizer = self.tokenizer
 
-        # Truncate prompts that are too long
-        if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warning(
-                "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
-            )
-            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        # Validate prompts length
+        validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
 
         self.waiting_queue.append(req)
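The two call sites use the helper differently: the generate path returns early and, as I read the diff, still enqueues the aborted request so its FINISH_ABORT reason can be surfaced to the caller, while the embedding path relies only on the helper's side effects before enqueueing. Note also that the old generate-path check used a strict ">" while the shared helper settles on ">=", matching the second site. A stripped-down sketch of the new flow, using hypothetical stand-ins (FakeReq, a plain list as the queue) rather than sglang classes:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeReq:
    # Hypothetical stand-in for sglang's Req, holding only what this sketch needs.
    origin_input_ids: List[int]
    finished_reason: Optional[str] = None


def admit(req: FakeReq, max_req_input_len: int, allow_auto_truncate: bool,
          waiting_queue: List[FakeReq]) -> None:
    # Mirrors the scheduler's new validate-then-enqueue flow.
    if len(req.origin_input_ids) >= max_req_input_len:
        if allow_auto_truncate:
            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
        else:
            # Analogue of FINISH_ABORT: mark the request finished with an
            # error and still enqueue it so the abort reaches the caller.
            req.finished_reason = "input too long"
            waiting_queue.append(req)
            return
    waiting_queue.append(req)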
python/sglang/srt/managers/tokenizer_manager.py
@@ -292,12 +292,28 @@ class TokenizerManager:
             SessionParams(**obj.session_params) if obj.session_params else None
         )
 
-        if obj.input_ids is not None and len(input_ids) >= self.context_len:
+        input_token_num = len(input_ids) if input_ids is not None else 0
+        if input_token_num >= self.context_len:
             raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
+        if (
+            obj.sampling_params.get("max_new_tokens") is not None
+            and obj.sampling_params.get("max_new_tokens") + input_token_num
+            >= self.context_len
+        ):
+            raise ValueError(
+                f"Requested token count exceeds the model's maximum context length "
+                f"of {self.context_len} tokens. You requested a total of "
+                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"tokens: {input_token_num} tokens from the input messages and "
+                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+                f"completion. Please reduce the number of tokens in the input "
+                f"messages or the completion to fit within the limit."
+            )
+
         # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
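The added check fails fast when the prompt plus the requested completion cannot fit in the context window, instead of letting generation start and run out of room later. A worked example of the arithmetic with illustrative numbers (the 100-token context mirrors the --context-length setting used in the new test):

context_len = 100        # model context length (illustrative value)
input_token_num = 60     # prompt length after tokenization
max_new_tokens = 50      # requested completion length

# 60 + 50 = 110 >= 100, so the request is rejected up front with the
# "Requested token count exceeds the model's maximum context length ..." error.
assert max_new_tokens + input_token_num >= context_len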
python/sglang/srt/managers/utils.py (new file, mode 100644)
import logging
from typing import Optional

from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

logger = logging.getLogger(__name__)


def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Validate and potentially truncate input length.

    Args:
        req: The request containing input_ids to validate
        max_req_input_len: Maximum allowed input length
        allow_auto_truncate: Whether to truncate long inputs

    Returns:
        Error message if validation fails, None if successful
    """
    if len(req.origin_input_ids) >= max_req_input_len:
        if allow_auto_truncate:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated. "
                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
            )
            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
            return None
        else:
            error_msg = (
                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
                f"the maximum allowed length ({max_req_input_len} tokens). "
                f"Use a shorter input or enable --allow-auto-truncate."
            )
            logger.error(error_msg)
            req.finished_reason = FINISH_ABORT(error_msg)
            return error_msg

    return None
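A minimal usage sketch of the new helper, assuming sglang is importable; the SimpleNamespace objects are hypothetical stand-ins for Req, carrying only the attributes the helper reads or writes:

from types import SimpleNamespace

from sglang.srt.managers.utils import validate_input_length

# Truncation allowed: the prompt is clipped in place and None is returned.
req = SimpleNamespace(origin_input_ids=list(range(2048)), finished_reason=None)
assert validate_input_length(req, 1024, True) is None
assert len(req.origin_input_ids) == 1024

# Truncation disallowed: the request is marked FINISH_ABORT and the error
# string is returned for the scheduler to act on.
req2 = SimpleNamespace(origin_input_ids=list(range(2048)), finished_reason=None)
error_msg = validate_input_length(req2, 1024, False)
assert error_msg is not None and "exceeds the maximum allowed length" in error_msg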
python/sglang/srt/server_args.py
@@ -157,6 +157,7 @@ class ServerArgs:
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
+    allow_auto_truncate: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -859,6 +860,11 @@ class ServerArgs:
             action="store_true",
             help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
         )
+        parser.add_argument(
+            "--allow-auto-truncate",
+            action="store_true",
+            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
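The flag defaults to False, so strict validation is the new default and truncation becomes opt-in. A small sketch of enabling it programmatically, assuming ServerArgs can be constructed with just a model path (its other fields have defaults); on the command line this corresponds to passing --allow-auto-truncate at launch:

from sglang.srt.server_args import ServerArgs

# The model path is a placeholder; allow_auto_truncate restores the old
# truncate-instead-of-error behavior when set to True.
args = ServerArgs(model_path="placeholder/model", allow_auto_truncate=True)
assert args.allow_auto_truncate is True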
test/srt/run_suite.py
@@ -31,6 +31,7 @@ suites = {
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
         "test_release_memory_occupation.py",
+        "test_request_length_validation.py",
         "test_retract_decode.py",
         "test_server_args.py",
         "test_session_control.py",
test/srt/test_request_length_validation.py (new file, mode 100644)
import unittest

import openai

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestRequestLengthValidation(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Start server with auto truncate disabled
        cls.process = popen_launch_server(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=("--max-total-tokens", "1000", "--context-length", "100"),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_input_length_validation(self):
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        long_text = "hello " * 100  # Will tokenize to more than context length

        with self.assertRaises(openai.BadRequestError) as cm:
            client.chat.completions.create(
                model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                messages=[
                    {"role": "user", "content": long_text},
                ],
                temperature=0,
            )

        self.assertIn("is longer than the model's context length", str(cm.exception))

    def test_max_tokens_validation(self):
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        long_text = "hello "

        with self.assertRaises(openai.BadRequestError) as cm:
            client.chat.completions.create(
                model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                messages=[
                    {"role": "user", "content": long_text},
                ],
                temperature=0,
                max_tokens=500,
            )

        self.assertIn(
            "Requested token count exceeds the model's maximum context",
            str(cm.exception),
        )


if __name__ == "__main__":
    unittest.main()