Unverified Commit 747dd450 authored by harrisonlimh, committed by GitHub

feat: throttle requests at scheduler based on --max_queued_requests (#7565)

parent b5821592
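In short, once the scheduler's waiting queue already holds `--max-queued-requests` entries, new generate/embedding requests are rejected with HTTP 503 instead of queuing without bound. A minimal client-side sketch of the resulting behavior (the URL, port, and prompt are placeholders, not part of this change):

```python
# Sketch only: assumes a server already launched with --max-queued-requests
# and listening on the default local port; adjust the URL to your deployment.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "What is the capital of France?",
        "sampling_params": {"temperature": 0, "max_new_tokens": 50},
    },
)

if resp.status_code == 503:
    # The waiting queue was full; back off and retry later.
    print("throttled:", resp.json().get("message"))
else:
    resp.raise_for_status()
    print(resp.json())
```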
......@@ -38,7 +38,7 @@ import orjson
import requests
import uvicorn
import uvloop
from fastapi import Depends, FastAPI, Request, UploadFile
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
......@@ -174,6 +174,18 @@ app.add_middleware(
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
"""Enrich HTTP exception with status code and other details"""
error = ErrorResponse(
object="error",
message=exc.detail,
type=str(exc.status_code),
code=exc.status_code,
)
return ORJSONResponse(content=error.model_dump(), status_code=exc.status_code)
# Custom exception handlers to change validation error status codes
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
......
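With the new handler in place, a throttled request surfaces to the client as a 503 whose JSON body is shaped roughly as follows (a sketch derived from the ErrorResponse fields used above; the serialized model may include additional optional fields):

```python
# Approximate body returned alongside HTTP status 503 (sketch, not verbatim server output).
{
    "object": "error",
    "message": "The request queue is full.",
    "type": "503",
    "code": 503,
}
```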
......@@ -4,7 +4,7 @@ import uuid
from abc import ABC, abstractmethod
from typing import Any, Optional, Union
from fastapi import Request
from fastapi import HTTPException, Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
......@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
return await self._handle_non_streaming_request(
adapted_request, processed_request, raw_request
)
except HTTPException as e:
return self.create_error_response(
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
)
except Exception as e:
logger.exception(f"Error in request: {e}")
return self.create_error_response(
......
......@@ -911,6 +911,8 @@ class AbortReq:
rid: str = ""
# Whether to abort all requests
abort_all: bool = False
# The finished reason data
finished_reason: Optional[Dict[str, Any]] = None
@dataclass
......
......@@ -24,6 +24,7 @@ import time
from collections import defaultdict, deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
......@@ -370,6 +371,7 @@ class Scheduler(
self.max_total_num_tokens,
self.max_prefill_tokens,
self.max_running_requests,
self.max_queued_requests,
self.max_req_len,
self.max_req_input_len,
self.random_seed,
......@@ -1086,6 +1088,19 @@ class Scheduler(
self.return_health_check_ct += 1
continue
# If it is a work request, accept or reject the request based on the request queue size.
if is_work_request(recv_req):
if len(self.waiting_queue) + 1 > self.max_queued_requests:
abort_req = AbortReq(
recv_req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "The request queue is full.",
},
)
self.send_to_tokenizer.send_pyobj(abort_req)
continue
output = self._request_dispatcher(recv_req)
if output is not None:
if isinstance(output, RpcReqOutput):
......@@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
def is_work_request(recv_req):
return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
def _export_static_state(model):
return dict(
buffers=[
......
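The admission decision itself is a bounded-queue check applied as work requests are received. A stripped-down, self-contained sketch of the policy (class and method names here are illustrative, not the real scheduler API):

```python
# Illustrative admission policy; the real check lives in the scheduler's receive loop above.
from collections import deque


class TinyScheduler:
    def __init__(self, max_queued_requests: int):
        self.max_queued_requests = max_queued_requests
        self.waiting_queue: deque = deque()

    def admit(self, rid: str) -> bool:
        # Reject when accepting this request would exceed the queue cap;
        # the caller is expected to answer the client with HTTP 503.
        if len(self.waiting_queue) + 1 > self.max_queued_requests:
            return False
        self.waiting_queue.append(rid)
        return True


sched = TinyScheduler(max_queued_requests=1)
assert sched.admit("req-1") is True
assert sched.admit("req-2") is False  # queue already full, would be throttled
```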
......@@ -766,6 +766,19 @@ class TokenizerManager:
):
raise ValueError(finish_reason["message"])
if (
finish_reason.get("type") == "abort"
and finish_reason.get("status_code")
== HTTPStatus.SERVICE_UNAVAILABLE
):
# This abort was initiated by the scheduler. Delete the key to prevent
# resending the abort request to the scheduler and to ensure the aborted
# request's state is cleaned up.
del self.rid_to_state[state.obj.rid]
raise fastapi.HTTPException(
status_code=finish_reason["status_code"],
detail=finish_reason["message"],
)
yield out
break
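Together with the scheduler hunk above, this completes the throttling round trip: the scheduler attaches a SERVICE_UNAVAILABLE finished_reason to the AbortReq, and the tokenizer converts it into an HTTPException rather than a normal finish. A condensed sketch of that contract (the helper function is illustrative, not from the codebase):

```python
from http import HTTPStatus

# Payload attached by the scheduler when the waiting queue is full (see AbortReq above).
finished_reason = {
    "type": "abort",
    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
    "message": "The request queue is full.",
}


def is_queue_full_abort(reason: dict) -> bool:
    # The tokenizer treats exactly this combination as an HTTP 503.
    return (
        reason.get("type") == "abort"
        and reason.get("status_code") == HTTPStatus.SERVICE_UNAVAILABLE
    )


assert is_queue_full_abort(finished_reason)
```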
......@@ -1705,8 +1718,15 @@ class TokenizerManager:
def _handle_abort_req(self, recv_obj):
state = self.rid_to_state[recv_obj.rid]
state.finished = True
state.out_list.append(
{
if recv_obj.finished_reason:
out = {
"meta_info": {
"id": recv_obj.rid,
"finish_reason": recv_obj.finished_reason,
},
}
else:
out = {
"text": "",
"meta_info": {
"id": recv_obj.rid,
......@@ -1718,7 +1738,7 @@ class TokenizerManager:
"completion_tokens": 0,
},
}
)
state.out_list.append(out)
state.event.set()
def _handle_open_session_req_output(self, recv_obj):
......@@ -1910,8 +1930,10 @@ class _Communicator(Generic[T]):
#
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http | yes | validation | background task | fast api | del in _handle_abort_req |
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
# | http | yes | running | background task | fast api | del in _handle_batch_output |
# | http | no | validation | http exception | http exception | del in _handle_abort_req |
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
#
......@@ -130,6 +130,10 @@ class TpModelWorker:
self.model_runner.req_to_token_pool.size,
)
assert self.max_running_requests > 0, "max_running_request is zero"
self.max_queued_requests = server_args.max_queued_requests
assert (
    self.max_queued_requests > 0
), "max_queued_requests is zero. It must be at least 1 to schedule a request."
self.max_req_len = min(
self.model_config.context_len - 1,
self.max_total_num_tokens - 1,
......@@ -165,6 +169,7 @@ class TpModelWorker:
self.max_total_num_tokens,
self.max_prefill_tokens,
self.max_running_requests,
self.max_queued_requests,
self.max_req_len,
self.max_req_input_len,
self.random_seed,
......
......@@ -19,6 +19,7 @@ import json
import logging
import os
import random
import sys
import tempfile
from typing import List, Literal, Optional, Union
......@@ -74,6 +75,7 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_queued_requests: Optional[int] = sys.maxsize
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
......@@ -805,6 +807,12 @@ class ServerArgs:
default=ServerArgs.max_running_requests,
help="The maximum number of running requests.",
)
parser.add_argument(
"--max-queued-requests",
type=int,
default=ServerArgs.max_queued_requests,
help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.",
)
parser.add_argument(
"--max-total-tokens",
type=int,
......
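The new flag is registered alongside the existing --max-running-requests option; with the default of sys.maxsize the queue is effectively unbounded. A hedged launch sketch (the entry point is the standard SGLang launcher, but the model path and limits below are placeholders, not part of this diff):

```python
# Sketch: launching a server with both limits set (model path and values are assumptions).
import subprocess

proc = subprocess.Popen(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
        "--max-running-requests", "8",
        "--max-queued-requests", "32",  # flag added by this change
    ]
)
```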
......@@ -19,6 +19,7 @@ from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple
import aiohttp
import numpy as np
import requests
import torch
......@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
raise
def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
    """Sends generate requests serially and returns their HTTP status codes. Max concurrency is 1."""
def generate():
prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
response = requests.post(
f"{base_url}/generate",
json={
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 50,
},
},
)
return response.status_code
return [generate() for _ in range(num_requests)]
async def send_concurrent_generate_requests(
    base_url: str, num_requests: int
) -> List[int]:
    """Sends generate requests concurrently and returns their HTTP status codes. Max concurrency is num_requests."""
async def async_generate():
async with aiohttp.ClientSession() as session:
prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
async with session.post(
f"{base_url}/generate",
json={
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 50,
},
},
) as response:
return response.status
tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
return await asyncio.gather(*tasks)
class CustomTestCase(unittest.TestCase):
def _callTestMethod(self, method):
max_retry = int(
......
......@@ -86,6 +86,7 @@ suites = {
TestFile("test_radix_attention.py", 105),
TestFile("test_regex_constrained.py", 64),
TestFile("test_retract_decode.py", 54),
TestFile("test_request_queue_validation.py", 30),
TestFile("test_server_args.py", 1),
TestFile("test_skip_tokenizer_init.py", 117),
TestFile("test_srt_engine.py", 261),
......
import asyncio
import os
import re
import unittest
from concurrent.futures import ThreadPoolExecutor
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
STDERR_FILENAME,
STDOUT_FILENAME,
CustomTestCase,
popen_launch_server,
send_concurrent_generate_requests,
send_generate_requests,
)
class TestMaxQueuedRequests(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.stdout = open(STDOUT_FILENAME, "w")
cls.stderr = open(STDERR_FILENAME, "w")
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--max-running-requests", # Enforce max request concurrency is 1
"1",
"--max-queued-requests", # Enforce max queued request number is 1
"1",
),
return_stdout_stderr=(cls.stdout, cls.stderr),
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
cls.stdout.close()
cls.stderr.close()
os.remove(STDOUT_FILENAME)
os.remove(STDERR_FILENAME)
def test_max_queued_requests_validation_with_serial_requests(self):
"""Verify request is not throttled when the max concurrency is 1."""
status_codes = send_generate_requests(
self.base_url,
num_requests=10,
)
for status_code in status_codes:
assert status_code == 200 # request shouldn't be throttled
def test_max_queued_requests_validation_with_concurrent_requests(self):
"""Verify request throttling with concurrent requests."""
status_codes = asyncio.run(
send_concurrent_generate_requests(self.base_url, num_requests=10)
)
assert 200 in status_codes
assert 503 in status_codes
assert all(status_code in [200, 503] for status_code in status_codes)
def test_max_running_requests_and_max_queued_request_validation(self):
"""Verify running request and queued request numbers based on server logs."""
rr_pattern = re.compile(r"#running-req:\s*(\d+)")
qr_pattern = re.compile(r"#queue-req:\s*(\d+)")
with open(STDERR_FILENAME) as lines:
for line in lines:
rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line)
if rr_match:
assert int(rr_match.group(1)) <= 1
if qr_match:
assert int(qr_match.group(1)) <= 1
if __name__ == "__main__":
unittest.main()