"vscode:/vscode.git/clone" did not exist on "1a1f38e3fe787fdceb205802ef946d8df79966c2"
Unverified commit 7023f413, authored by Lianmin Zheng, committed by GitHub

Clean up (#422)

parent 09deb20d
......@@ -2,10 +2,11 @@ import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, None)
launch_server(server_args, None)
\ No newline at end of file
......@@ -37,6 +37,7 @@ from sglang.srt.utils import (
)
logger = logging.getLogger("model_rpc")
logging.getLogger("vllm.utils").setLevel(logging.WARN)
class ModelRpcServer:
......@@ -113,7 +114,7 @@ class ModelRpcServer:
f"max_prefill_num_token={self.max_prefill_num_token}, "
f"context_len={self.model_config.context_len}, "
)
logger.info(server_args.get_optional_modes_logging())
logger.info(f"server_args: {server_args.print_mode_args()}")
# Init cache
self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
......
......@@ -28,7 +28,6 @@ QUANTIZATION_CONFIG_MAPPING = {
logger = logging.getLogger("model_runner")
# for server args in model endpoints
global_server_args_dict: dict = None
......@@ -276,9 +275,6 @@ class ModelRunner:
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
)
# A small all_reduce for warmup.
if self.tp_size > 1:
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
total_gpu_memory = get_available_gpu_memory(
......
......@@ -15,15 +15,11 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import aiohttp
import psutil
import pydantic
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
......@@ -37,7 +33,7 @@ from sglang.srt.conversation import (
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
from sglang.srt.managers.openai_protocol import (
from sglang.srt.openai_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
......@@ -56,45 +52,24 @@ from sglang.srt.managers.openai_protocol import (
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import enable_show_time_cost, handle_port_init
from sglang.srt.utils import (
enable_show_time_cost,
allocate_init_ports,
jsonify_pydantic_model,
assert_pkg_version,
get_exception_traceback,
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware
)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
def __init__(self, app, api_key: str):
super().__init__(app)
self.api_key = api_key
async def dispatch(self, request: Request, call_next):
# extract API key from the request headers
api_key_header = request.headers.get(API_KEY_HEADER_NAME)
if not api_key_header or api_key_header != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid API Key"},
)
response = await call_next(request)
return response
app = FastAPI()
tokenizer_manager = None
chat_template_name = None
# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
def jsonify_pydantic_model(obj: BaseModel):
if IS_PYDANTIC_1:
return obj.json(ensure_ascii=False)
return obj.model_dump_json()
@app.get("/health")
async def health() -> Response:
"""Health check."""
......@@ -124,6 +99,31 @@ async def flush_cache():
)
async def stream_generator(obj: GenerateReqInput):
async for out in tokenizer_manager.generate_request(obj):
await handle_token_logprobs_results(obj, out)
yield out
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
obj.post_init()
if obj.stream:
async def stream_results():
async for out in stream_generator(obj):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream")
ret = await tokenizer_manager.generate_request(obj).__anext__()
await handle_token_logprobs_results(obj, ret)
return ret
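Review note: a minimal client-side sketch for exercising the /generate endpoint above. The address, prompt, and sampling parameters are illustrative placeholders; the field names mirror the warmup request sent by _wait_and_warmup further down.

import requests

base_url = "http://127.0.0.1:30000"  # assumed default host/port from ServerArgs

resp = requests.post(
    base_url + "/generate",
    json={
        "text": "The capital of France is",  # placeholder prompt
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["text"])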
async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
......@@ -175,68 +175,6 @@ async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
await convert_style(r, obj.return_text_in_logprobs)
async def stream_generator(obj: GenerateReqInput):
async for out in tokenizer_manager.generate_request(obj):
await handle_token_logprobs_results(obj, out)
yield out
async def make_openai_style_logprobs(
prefill_token_logprobs=None,
decode_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
):
ret_logprobs = LogProbs()
def append_token_logprobs(token_logprobs):
for logprob, _, token_text in token_logprobs:
ret_logprobs.tokens.append(token_text)
ret_logprobs.token_logprobs.append(logprob)
# Not Supported yet
ret_logprobs.text_offset.append(-1)
def append_top_logprobs(top_logprobs):
for tokens in top_logprobs:
if tokens is not None:
ret_logprobs.top_logprobs.append(
{token[2]: token[0] for token in tokens}
)
else:
ret_logprobs.top_logprobs.append(None)
if prefill_token_logprobs is not None:
append_token_logprobs(prefill_token_logprobs)
if decode_token_logprobs is not None:
append_token_logprobs(decode_token_logprobs)
if prefill_top_logprobs is not None:
append_top_logprobs(prefill_top_logprobs)
if decode_top_logprobs is not None:
append_top_logprobs(decode_top_logprobs)
return ret_logprobs
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
obj.post_init()
if obj.stream:
async def stream_results():
async for out in stream_generator(obj):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream")
ret = await tokenizer_manager.generate_request(obj).__anext__()
await handle_token_logprobs_results(obj, ret)
return ret
@app.post("/v1/completions")
async def v1_completions(raw_request: Request):
request_json = await raw_request.json()
......@@ -500,27 +438,97 @@ async def v1_chat_completions(raw_request: Request):
return response
def launch_server(server_args: ServerArgs, pipe_finish_writer):
global tokenizer_manager
async def make_openai_style_logprobs(
prefill_token_logprobs=None,
decode_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
):
ret_logprobs = LogProbs()
def append_token_logprobs(token_logprobs):
for logprob, _, token_text in token_logprobs:
ret_logprobs.tokens.append(token_text)
ret_logprobs.token_logprobs.append(logprob)
# Not Supported yet
ret_logprobs.text_offset.append(-1)
def append_top_logprobs(top_logprobs):
for tokens in top_logprobs:
if tokens is not None:
ret_logprobs.top_logprobs.append(
{token[2]: token[0] for token in tokens}
)
else:
ret_logprobs.top_logprobs.append(None)
if prefill_token_logprobs is not None:
append_token_logprobs(prefill_token_logprobs)
if decode_token_logprobs is not None:
append_token_logprobs(decode_token_logprobs)
if prefill_top_logprobs is not None:
append_top_logprobs(prefill_top_logprobs)
if decode_top_logprobs is not None:
append_top_logprobs(decode_top_logprobs)
return ret_logprobs
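Review note: the tuples consumed here are (logprob, token_id, token_text), matching what detokenize_logprob_tokens produces above; a tiny illustrative call with made-up values.

import asyncio

prefill = [(-0.12, 101, "Hello"), (-1.34, 102, ",")]  # hypothetical (logprob, token_id, token_text)
decode = [(-0.56, 103, " world")]

logprobs = asyncio.run(
    make_openai_style_logprobs(
        prefill_token_logprobs=prefill,
        decode_token_logprobs=decode,
    )
)
print(logprobs.tokens)          # ['Hello', ',', ' world']
print(logprobs.token_logprobs)  # [-0.12, -1.34, -0.56]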
def load_chat_template_for_openai_api(chat_template_arg):
global chat_template_name
if server_args.enable_flashinfer:
from sglang.srt.utils import assert_pkg_version
assert_pkg_version("flashinfer", "0.0.4")
print(f"Use chat template: {chat_template_arg}")
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
)
with open(chat_template_arg, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(
f"Unknown separator style: {template['sep_style']}"
) from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
chat_template_name = template["name"]
else:
chat_template_name = chat_template_arg
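Review note: load_chat_template_for_openai_api accepts either a built-in template name or a path to a JSON file. A sketch of such a file written from Python, with illustrative values; the sep_style string is assumed to name a member of SeparatorStyle in sglang.srt.conversation.

import json

template = {
    "name": "my_custom_chat",                  # hypothetical template name
    "system": "<|system|>",                    # becomes system_template + "\n{system_message}"
    "system_message": "You are a helpful assistant.",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "sep_style": "ADD_COLON_SINGLE",           # assumed to be a valid SeparatorStyle member
    "sep": "\n",
    "stop_str": "<|user|>",
}

with open("my_chat_template.json", "w") as fout:
    json.dump(template, fout)

# Launch with: --chat-template my_chat_template.json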
# start show time thread
def launch_server(server_args: ServerArgs, pipe_finish_writer):
global tokenizer_manager
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
if server_args.show_time_cost:
enable_show_time_cost()
# disable disk cache if needed
if server_args.disable_disk_cache:
disable_cache()
if server_args.enable_flashinfer:
assert_pkg_version("flashinfer", "0.0.4")
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
# Handle ports
server_args.port, server_args.additional_ports = handle_port_init(
# Allocate ports
server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port, server_args.additional_ports, server_args.tp_size
)
port_args = PortArgs(
tokenizer_port=server_args.additional_ports[0],
router_port=server_args.additional_ports[1],
......@@ -529,39 +537,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
model_rpc_ports=server_args.additional_ports[4:],
)
# Load chat template if needed
if server_args.chat_template is not None:
print(f"Use chat template: {server_args.chat_template}")
if not chat_template_exists(server_args.chat_template):
if not os.path.exists(server_args.chat_template):
raise RuntimeError(
f"Chat template {server_args.chat_template} is not a built-in template name "
"or a valid chat template file path."
)
with open(server_args.chat_template, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(
f"Unknown separator style: {template['sep_style']}"
) from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
chat_template_name = template["name"]
else:
chat_template_name = server_args.chat_template
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args)
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
......@@ -593,31 +568,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
if router_init_state != "init ok" or detoken_init_state != "init ok":
proc_router.kill()
proc_detoken.kill()
print("router init state:", router_init_state)
print("detoken init state:", detoken_init_state)
print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
sys.exit(1)
assert proc_router.is_alive() and proc_detoken.is_alive()
if server_args.api_key and server_args.api_key != "":
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
def _launch_server():
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
def _wait_and_warmup():
headers = {}
url = server_args.url()
if server_args.api_key and server_args.api_key != "":
if server_args.api_key:
headers[API_KEY_HEADER_NAME] = server_args.api_key
# Wait until the server is launched
for _ in range(120):
time.sleep(0.5)
try:
......@@ -625,16 +590,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
break
except requests.exceptions.RequestException as e:
pass
else:
if pipe_finish_writer is not None:
pipe_finish_writer.send(str(e))
else:
print(e, flush=True)
return
# Warmup
# Send a warmup request
try:
# print("Warmup...", flush=True)
res = requests.post(
url + "/generate",
json={
......@@ -647,14 +605,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
headers=headers,
timeout=60,
)
# print(f"Warmup done. model response: {res.json()['text']}")
# print("=" * 20, "Server is ready", "=" * 20, flush=True)
except requests.exceptions.RequestException as e:
assert res.status_code == 200
except Exception as e:
if pipe_finish_writer is not None:
pipe_finish_writer.send(str(e))
else:
print(e, flush=True)
return
pipe_finish_writer.send(get_exception_traceback())
print(f"Initialization failed. warmup error: {e}")
raise e
if pipe_finish_writer is not None:
pipe_finish_writer.send("init ok")
......@@ -662,7 +618,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
t = threading.Thread(target=_wait_and_warmup)
t.start()
try:
_launch_server()
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
finally:
t.join()
......@@ -670,52 +633,16 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
class Runtime:
def __init__(
self,
model_path: str,
tokenizer_path: Optional[str] = None,
load_format: str = "auto",
tokenizer_mode: str = "auto",
trust_remote_code: bool = True,
mem_fraction_static: float = ServerArgs.mem_fraction_static,
max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
context_length: int = ServerArgs.context_length,
tp_size: int = 1,
schedule_heuristic: str = "lpm",
attention_reduce_in_fp32: bool = False,
random_seed: int = 42,
log_level: str = "error",
disable_radix_cache: bool = False,
enable_flashinfer: bool = False,
disable_regex_jump_forward: bool = False,
disable_disk_cache: bool = False,
api_key: str = "",
port: Optional[int] = None,
additional_ports: Optional[Union[List[int], int]] = None,
log_level="error",
*args,
**kwargs,
):
host = "127.0.0.1"
port, additional_ports = handle_port_init(port, additional_ports, tp_size)
self.server_args = ServerArgs(
model_path=model_path,
tokenizer_path=tokenizer_path,
host=host,
port=port,
additional_ports=additional_ports,
load_format=load_format,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
mem_fraction_static=mem_fraction_static,
max_prefill_num_token=max_prefill_num_token,
context_length=context_length,
tp_size=tp_size,
schedule_heuristic=schedule_heuristic,
attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed,
log_level=log_level,
disable_radix_cache=disable_radix_cache,
enable_flashinfer=enable_flashinfer,
disable_regex_jump_forward=disable_regex_jump_forward,
disable_disk_cache=disable_disk_cache,
api_key=api_key,
)
"""See the arguments in server_args.py::ServerArgs"""
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
# Pre-allocate ports
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
self.url = self.server_args.url()
self.generate_url = (
......@@ -736,7 +663,7 @@ class Runtime:
if init_state != "init ok":
self.shutdown()
raise RuntimeError("Launch failed. Please see the error messages above.")
raise RuntimeError("Initialization failed. Please see the error messages above.")
self.endpoint = RuntimeEndpoint(self.url)
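Review note: since Runtime now forwards *args/**kwargs straight to ServerArgs, any ServerArgs field can be passed as a keyword argument; a minimal usage sketch with a placeholder model path.

from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    tp_size=1,
)
print(runtime.url)  # e.g. http://127.0.0.1:30000
runtime.shutdown()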
......@@ -765,13 +692,12 @@ class Runtime:
self,
prompt: str,
sampling_params,
) -> None:
):
json_data = {
"text": prompt,
"sampling_params": sampling_params,
"stream": True,
}
pos = 0
timeout = aiohttp.ClientTimeout(total=3 * 3600)
......
"""The arguments of the server."""
import argparse
import dataclasses
from typing import List, Optional, Union
......@@ -5,33 +7,44 @@ from typing import List, Optional, Union
@dataclasses.dataclass
class ServerArgs:
# Model and tokenizer
model_path: str
tokenizer_path: Optional[str] = None
host: str = "127.0.0.1"
port: int = 30000
additional_ports: Optional[Union[List[int], int]] = None
load_format: str = "auto"
tokenizer_mode: str = "auto"
chat_template: Optional[str] = None
trust_remote_code: bool = True
context_length: Optional[int] = None
# Port
host: str = "127.0.0.1"
port: int = 30000
additional_ports: Optional[Union[List[int], int]] = None
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_prefill_num_token: Optional[int] = None
context_length: Optional[int] = None
tp_size: int = 1
schedule_heuristic: str = "lpm"
schedule_conservativeness: float = 1.0
attention_reduce_in_fp32: bool = False
random_seed: int = 42
# Other runtime options
tp_size: int = 1
stream_interval: int = 8
random_seed: int = 42
# Logging
log_level: str = "info"
disable_log_stats: bool = False
log_stats_interval: int = 10
log_level: str = "info"
api_key: str = ""
show_time_cost: bool = False
# optional modes
disable_radix_cache: bool = False
# Other
api_key: str = ""
# Optimization/debug options
enable_flashinfer: bool = False
attention_reduce_in_fp32: bool = False
disable_radix_cache: bool = False
disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False
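Review note: the reorganized ServerArgs remains a plain dataclass, so it can be constructed directly in Python as well as from the CLI; a short sketch with a placeholder model path.

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    tp_size=2,
    enable_flashinfer=True,
)
print(args.url())              # http://127.0.0.1:30000
print(args.print_mode_args())  # enable_flashinfer=True, disable_radix_cache=False, ...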
......@@ -66,15 +79,16 @@ class ServerArgs:
default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.",
)
parser.add_argument("--host", type=str, default=ServerArgs.host)
parser.add_argument("--port", type=int, default=ServerArgs.port)
# we want to be able to pass a list of ports
parser.add_argument("--host", type=str, default=ServerArgs.host,
help="The host of the server.")
parser.add_argument("--port", type=int, default=ServerArgs.port,
help="The port of the server.")
parser.add_argument(
"--additional-ports",
type=int,
nargs="*",
default=[],
help="Additional ports specified for launching server.",
help="Additional ports specified for the server.",
)
parser.add_argument(
"--load-format",
......@@ -112,6 +126,12 @@ class ServerArgs:
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--context-length",
type=int,
default=ServerArgs.context_length,
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
)
parser.add_argument(
"--mem-fraction-static",
type=float,
......@@ -124,18 +144,6 @@ class ServerArgs:
default=ServerArgs.max_prefill_num_token,
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
)
parser.add_argument(
"--context-length",
type=int,
default=ServerArgs.context_length,
help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
)
parser.add_argument(
"--tp-size",
type=int,
default=ServerArgs.tp_size,
help="Tensor parallelism degree.",
)
parser.add_argument(
"--schedule-heuristic",
type=str,
......@@ -149,15 +157,10 @@ class ServerArgs:
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
)
parser.add_argument(
"--random-seed",
"--tp-size",
type=int,
default=ServerArgs.random_seed,
help="Random seed.",
)
parser.add_argument(
"--attention-reduce-in-fp32",
action="store_true",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
default=ServerArgs.tp_size,
help="Tensor parallelism size.",
)
parser.add_argument(
"--stream-interval",
......@@ -165,11 +168,17 @@ class ServerArgs:
default=ServerArgs.stream_interval,
help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
)
parser.add_argument(
"--random-seed",
type=int,
default=ServerArgs.random_seed,
help="Random seed.",
)
parser.add_argument(
"--log-level",
type=str,
default=ServerArgs.log_level,
help="Log level",
help="Logging level",
)
parser.add_argument(
"--disable-log-stats",
......@@ -182,28 +191,33 @@ class ServerArgs:
default=ServerArgs.log_stats_interval,
help="Log stats interval in second.",
)
parser.add_argument(
"--show-time-cost",
action="store_true",
help="Show time cost of custom marks",
)
parser.add_argument(
"--api-key",
type=str,
default=ServerArgs.api_key,
help="Set API Key",
help="Set API key of the server",
)
# Optimization/debug options
parser.add_argument(
"--show-time-cost",
"--enable-flashinfer",
action="store_true",
help="Show time cost of custom marks",
help="Enable flashinfer inference kernels",
)
# optional modes
parser.add_argument(
"--disable-radix-cache",
"--attention-reduce-in-fp32",
action="store_true",
help="Disable RadixAttention",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
)
parser.add_argument(
"--enable-flashinfer",
"--disable-radix-cache",
action="store_true",
help="Enable flashinfer inference kernels",
help="Disable RadixAttention",
)
parser.add_argument(
"--disable-regex-jump-forward",
......@@ -224,13 +238,13 @@ class ServerArgs:
def url(self):
return f"http://{self.host}:{self.port}"
def get_optional_modes_logging(self):
def print_mode_args(self):
return (
f"disable_radix_cache={self.disable_radix_cache}, "
f"enable_flashinfer={self.enable_flashinfer}, "
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
f"disable_radix_cache={self.disable_radix_cache}, "
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
f"disable_disk_cache={self.disable_disk_cache}, "
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
)
......@@ -240,4 +254,4 @@ class PortArgs:
router_port: int
detokenizer_port: int
nccl_port: int
model_rpc_ports: List[int]
model_rpc_ports: List[int]
\ No newline at end of file
......@@ -10,9 +10,12 @@ from io import BytesIO
from typing import List, Optional
import numpy as np
import pydantic
import requests
import torch
from packaging import version as pkg_version
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
show_time_cost = False
time_infos = {}
......@@ -120,7 +123,7 @@ def check_port(port):
return False
def handle_port_init(
def allocate_init_ports(
port: Optional[int] = None,
additional_ports: Optional[List[int]] = None,
tp_size: int = 1,
......@@ -159,8 +162,6 @@ def get_exception_traceback():
def get_int_token_logit_bias(tokenizer, vocab_size):
from transformers import LlamaTokenizer, LlamaTokenizerFast
# a bug when model's vocab size > tokenizer.vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
......@@ -281,3 +282,32 @@ def assert_pkg_version(pkg: str, min_version: str):
)
except PackageNotFoundError:
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
def __init__(self, app, api_key: str):
super().__init__(app)
self.api_key = api_key
async def dispatch(self, request, call_next):
# extract API key from the request headers
api_key_header = request.headers.get(API_KEY_HEADER_NAME)
if not api_key_header or api_key_header != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid API Key"},
)
response = await call_next(request)
return response
# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
def jsonify_pydantic_model(obj: BaseModel):
if IS_PYDANTIC_1:
return obj.json(ensure_ascii=False)
return obj.model_dump_json()
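Review note: with API_KEY_HEADER_NAME and APIKeyValidatorMiddleware now living in sglang.srt.utils, the server wires them up roughly as below; a minimal standalone sketch in which the app and key value are illustrative.

from fastapi import FastAPI
from sglang.srt.utils import API_KEY_HEADER_NAME, APIKeyValidatorMiddleware

app = FastAPI()
api_key = "my-secret-key"  # placeholder

if api_key:  # mirrors launch_server: only guard the app when a key is configured
    app.add_middleware(APIKeyValidatorMiddleware, api_key=api_key)

# Clients then send the key on every request, e.g.:
#   requests.get(url + "/health", headers={API_KEY_HEADER_NAME: api_key})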
kill -9 $(ps aux | grep 'python' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang' | grep -v 'grep' | awk '{print $2}')