Unverified Commit a8ccacc8 authored by Chang Su, committed by GitHub

[Frontend] Fix request length check and add option to disallow auto truncation in scheduler (#2876)

parent 0427416b
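With this change, prompts that exceed the maximum input length are rejected by default instead of being silently truncated; truncation is now opt-in via the new --allow-auto-truncate server flag. A minimal client-side sketch of the new default behavior (the endpoint URL, API key, and model name are placeholders; the flow mirrors the test added at the bottom of this commit):

import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

try:
    client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "hello " * 100}],  # longer than the server's context length
        temperature=0,
    )
except openai.BadRequestError as e:
    # Without --allow-auto-truncate, the server now reports an error
    # instead of truncating the prompt.
    print(e)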
@@ -78,6 +78,7 @@ from sglang.srt.managers.schedule_policy import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+ from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -690,14 +691,16 @@ class Scheduler:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(req.origin_input_ids) - 1
- # Truncate prompts that are too long
- if len(req.origin_input_ids) > self.max_req_input_len:
- logger.warning(
- "Request length is longer than the KV cache pool size or "
- "the max context length. Truncated. "
- f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
- )
- req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+ # Validate prompts length
+ error_msg = validate_input_length(
+ req,
+ self.max_req_input_len,
+ self.server_args.allow_auto_truncate,
+ )
+ if error_msg:
+ self.waiting_queue.append(req)
+ return
req.sampling_params.max_new_tokens = min(
(
@@ -745,13 +748,12 @@
)
req.tokenizer = self.tokenizer
- # Truncate prompts that are too long
- if len(req.origin_input_ids) >= self.max_req_input_len:
- logger.warning(
- "Request length is longer than the KV cache pool size or "
- "the max context length. Truncated!!!"
- )
- req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+ # Validate prompts length
+ validate_input_length(
+ req,
+ self.max_req_input_len,
+ self.server_args.allow_auto_truncate,
+ )
self.waiting_queue.append(req)
......
@@ -292,12 +292,28 @@ class TokenizerManager:
SessionParams(**obj.session_params) if obj.session_params else None
)
- if obj.input_ids is not None and len(input_ids) >= self.context_len:
+ input_token_num = len(input_ids) if input_ids is not None else 0
+ if input_token_num >= self.context_len:
raise ValueError(
- f"The input ({len(input_ids)} tokens) is longer than the "
+ f"The input ({input_token_num} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
)
+ if (
+ obj.sampling_params.get("max_new_tokens") is not None
+ and obj.sampling_params.get("max_new_tokens") + input_token_num
+ >= self.context_len
+ ):
+ raise ValueError(
+ f"Requested token count exceeds the model's maximum context length "
+ f"of {self.context_len} tokens. You requested a total of "
+ f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+ f"tokens: {input_token_num} tokens from the input messages and "
+ f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+ f"completion. Please reduce the number of tokens in the input "
+ f"messages or the completion to fit within the limit."
+ )
# Parse sampling parameters
sampling_params = SamplingParams(**obj.sampling_params)
sampling_params.normalize(self.tokenizer)
......
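A small worked example of the new total-length check in the TokenizerManager above (the input token count is an assumed figure; the other numbers match the test added below):

context_len = 100       # model context length, e.g. --context-length 100
input_token_num = 8     # assumed token count of a short prompt (illustrative)
max_new_tokens = 500    # max_tokens requested by the client

# The request is rejected when the requested total reaches or exceeds the context length.
if max_new_tokens is not None and max_new_tokens + input_token_num >= context_len:
    print(f"rejected: {max_new_tokens + input_token_num} requested tokens vs. {context_len}-token context")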
import logging
from typing import Optional

from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

logger = logging.getLogger(__name__)


def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Validate and potentially truncate input length.

    Args:
        req: The request containing input_ids to validate
        max_req_input_len: Maximum allowed input length
        allow_auto_truncate: Whether to truncate long inputs

    Returns:
        Error message if validation fails, None if successful
    """
    if len(req.origin_input_ids) >= max_req_input_len:
        if allow_auto_truncate:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated. "
                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
            )
            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
            return None
        else:
            error_msg = (
                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
                f"the maximum allowed length ({max_req_input_len} tokens). "
                f"Use a shorter input or enable --allow-auto-truncate."
            )
            logger.error(error_msg)
            req.finished_reason = FINISH_ABORT(error_msg)
            return error_msg

    return None
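A quick sketch of how the new helper behaves in both modes, runnable if sglang is installed; the request object here is a hypothetical stand-in that carries only the attributes validate_input_length touches (the real Req holds much more state):

from sglang.srt.managers.utils import validate_input_length

class _FakeReq:
    # Hypothetical stand-in for sglang.srt.managers.schedule_batch.Req.
    def __init__(self, ids):
        self.origin_input_ids = ids
        self.finished_reason = None

# Truncation allowed: the prompt is clipped in place and no error is returned.
req = _FakeReq(list(range(150)))
assert validate_input_length(req, 100, allow_auto_truncate=True) is None
assert len(req.origin_input_ids) == 100

# Truncation disallowed: an error message is returned and the request is
# marked as aborted via finished_reason, as the scheduler changes above expect.
req = _FakeReq(list(range(150)))
error_msg = validate_input_length(req, 100, allow_auto_truncate=False)
assert error_msg is not None and req.finished_reason is not None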
@@ -157,6 +157,7 @@ class ServerArgs:
num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False
enable_memory_saver: bool = False
+ allow_auto_truncate: bool = False
def __post_init__(self):
# Set missing default values
@@ -859,6 +860,11 @@ class ServerArgs:
action="store_true",
help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
)
+ parser.add_argument(
+ "--allow-auto-truncate",
+ action="store_true",
+ help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
+ )
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
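For completeness, a sketch of launching a test server with the new flag enabled, using the same helper as the test below; the model, URL, and timeout constants come from sglang.test.test_utils, and the token limits are illustrative:

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

# With --allow-auto-truncate passed through, over-length prompts are truncated
# (the pre-existing behavior) instead of being rejected with an error.
process = popen_launch_server(
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=("--context-length", "100", "--allow-auto-truncate"),
)
# ... issue requests against DEFAULT_URL_FOR_TEST here ...
kill_process_tree(process.pid)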
@@ -31,6 +31,7 @@ suites = {
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_release_memory_occupation.py",
"test_request_length_validation.py",
"test_retract_decode.py",
"test_server_args.py",
"test_session_control.py",
......
import unittest

import openai

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestRequestLengthValidation(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Start server with auto truncate disabled
        cls.process = popen_launch_server(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=("--max-total-tokens", "1000", "--context-length", "100"),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_input_length_validation(self):
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        long_text = "hello " * 100  # Will tokenize to more than context length

        with self.assertRaises(openai.BadRequestError) as cm:
            client.chat.completions.create(
                model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                messages=[
                    {"role": "user", "content": long_text},
                ],
                temperature=0,
            )

        self.assertIn("is longer than the model's context length", str(cm.exception))

    def test_max_tokens_validation(self):
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        long_text = "hello "

        with self.assertRaises(openai.BadRequestError) as cm:
            client.chat.completions.create(
                model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                messages=[
                    {"role": "user", "content": long_text},
                ],
                temperature=0,
                max_tokens=500,
            )

        self.assertIn(
            "Requested token count exceeds the model's maximum context",
            str(cm.exception),
        )


if __name__ == "__main__":
    unittest.main()