Unverified commit 7023f413, authored by Lianmin Zheng and committed by GitHub

Clean up (#422)

parent 09deb20d
@@ -2,6 +2,7 @@ import argparse
 
 from sglang.srt.server import ServerArgs, launch_server
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
...
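For context, a minimal sketch of driving the same entry points programmatically instead of through this CLI wrapper; the model path and port below are placeholders, not values taken from the commit:

    # Hypothetical usage sketch (not part of the commit): launch the SRT server from Python.
    from sglang.srt.server import ServerArgs, launch_server

    if __name__ == "__main__":
        server_args = ServerArgs(
            model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model path
            port=30000,
        )
        # pipe_finish_writer may be None; launch_server only writes to it when provided.
        launch_server(server_args, None)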
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
 )
 
 logger = logging.getLogger("model_rpc")
+logging.getLogger("vllm.utils").setLevel(logging.WARN)
 
 
 class ModelRpcServer:
@@ -113,7 +114,7 @@ class ModelRpcServer:
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
         )
-        logger.info(server_args.get_optional_modes_logging())
+        logger.info(f"server_args: {server_args.print_mode_args()}")
 
         # Init cache
         self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
...
@@ -28,7 +28,6 @@ QUANTIZATION_CONFIG_MAPPING = {
 logger = logging.getLogger("model_runner")
 
 # for server args in model endpoints
 global_server_args_dict: dict = None
@@ -276,9 +275,6 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
 
-        # A small all_reduce for warmup.
-        if self.tp_size > 1:
-            torch.distributed.all_reduce(torch.zeros(1).cuda())
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
...
@@ -15,15 +15,11 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 import aiohttp
 import psutil
-import pydantic
 import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse
-from pydantic import BaseModel
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import JSONResponse
 
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
@@ -37,7 +33,7 @@ from sglang.srt.conversation import (
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
-from sglang.srt.managers.openai_protocol import (
+from sglang.srt.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
@@ -56,45 +52,24 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import enable_show_time_cost, handle_port_init
+from sglang.srt.utils import (
+    enable_show_time_cost,
+    allocate_init_ports,
+    jsonify_pydantic_model,
+    assert_pkg_version,
+    get_exception_traceback,
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware
+)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
-API_KEY_HEADER_NAME = "X-API-Key"
-
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request: Request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
 app = FastAPI()
 tokenizer_manager = None
 chat_template_name = None
 
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
-
 
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -124,6 +99,31 @@ async def flush_cache():
     )
 
 
+async def stream_generator(obj: GenerateReqInput):
+    async for out in tokenizer_manager.generate_request(obj):
+        await handle_token_logprobs_results(obj, out)
+        yield out
+
+
+@app.post("/generate")
+async def generate_request(obj: GenerateReqInput):
+    obj.post_init()
+
+    if obj.stream:
+
+        async def stream_results():
+            async for out in stream_generator(obj):
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(stream_results(), media_type="text/event-stream")
+
+    ret = await tokenizer_manager.generate_request(obj).__anext__()
+    await handle_token_logprobs_results(obj, ret)
+    return ret
+
+
 async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
     if not decode_to_text:
         return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
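The relocated /generate route streams results as server-sent events: each chunk is a line prefixed with "data: " and the stream ends with a "data: [DONE]" sentinel. A minimal client sketch for consuming it (host, port, prompt, and sampling parameters are placeholders, not part of the commit):

    # Hypothetical streaming client for the /generate endpoint shown above.
    import json
    import requests

    resp = requests.post(
        "http://127.0.0.1:30000/generate",  # placeholder host/port
        json={
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 32, "temperature": 0},
            "stream": True,
        },
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        print(json.loads(payload).get("text", ""), flush=True)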
@@ -175,68 +175,6 @@ async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
         await convert_style(r, obj.return_text_in_logprobs)
 
 
-async def stream_generator(obj: GenerateReqInput):
-    async for out in tokenizer_manager.generate_request(obj):
-        await handle_token_logprobs_results(obj, out)
-        yield out
-
-
-async def make_openai_style_logprobs(
-    prefill_token_logprobs=None,
-    decode_token_logprobs=None,
-    prefill_top_logprobs=None,
-    decode_top_logprobs=None,
-):
-    ret_logprobs = LogProbs()
-
-    def append_token_logprobs(token_logprobs):
-        for logprob, _, token_text in token_logprobs:
-            ret_logprobs.tokens.append(token_text)
-            ret_logprobs.token_logprobs.append(logprob)
-
-            # Not Supported yet
-            ret_logprobs.text_offset.append(-1)
-
-    def append_top_logprobs(top_logprobs):
-        for tokens in top_logprobs:
-            if tokens is not None:
-                ret_logprobs.top_logprobs.append(
-                    {token[2]: token[0] for token in tokens}
-                )
-            else:
-                ret_logprobs.top_logprobs.append(None)
-
-    if prefill_token_logprobs is not None:
-        append_token_logprobs(prefill_token_logprobs)
-    if decode_token_logprobs is not None:
-        append_token_logprobs(decode_token_logprobs)
-    if prefill_top_logprobs is not None:
-        append_top_logprobs(prefill_top_logprobs)
-    if decode_top_logprobs is not None:
-        append_top_logprobs(decode_top_logprobs)
-
-    return ret_logprobs
-
-
-@app.post("/generate")
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
-    if obj.stream:
-
-        async def stream_results():
-            async for out in stream_generator(obj):
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
-            yield "data: [DONE]\n\n"
-
-        return StreamingResponse(stream_results(), media_type="text/event-stream")
-
-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    await handle_token_logprobs_results(obj, ret)
-    return ret
-
-
 @app.post("/v1/completions")
 async def v1_completions(raw_request: Request):
     request_json = await raw_request.json()
@@ -500,45 +438,54 @@ async def v1_chat_completions(raw_request: Request):
     return response
 
 
-def launch_server(server_args: ServerArgs, pipe_finish_writer):
-    global tokenizer_manager
-    global chat_template_name
-
-    if server_args.enable_flashinfer:
-        from sglang.srt.utils import assert_pkg_version
-        assert_pkg_version("flashinfer", "0.0.4")
-
-    # start show time thread
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # disable disk cache if needed
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Handle ports
-    server_args.port, server_args.additional_ports = handle_port_init(
-        server_args.port, server_args.additional_ports, server_args.tp_size
-    )
-
-    port_args = PortArgs(
-        tokenizer_port=server_args.additional_ports[0],
-        router_port=server_args.additional_ports[1],
-        detokenizer_port=server_args.additional_ports[2],
-        nccl_port=server_args.additional_ports[3],
-        model_rpc_ports=server_args.additional_ports[4:],
-    )
-
-    # Load chat template if needed
-    if server_args.chat_template is not None:
-        print(f"Use chat template: {server_args.chat_template}")
-        if not chat_template_exists(server_args.chat_template):
-            if not os.path.exists(server_args.chat_template):
-                raise RuntimeError(
-                    f"Chat template {server_args.chat_template} is not a built-in template name "
-                    "or a valid chat template file path."
-                )
-            with open(server_args.chat_template, "r") as filep:
-                template = json.load(filep)
-                try:
-                    sep_style = SeparatorStyle[template["sep_style"]]
+async def make_openai_style_logprobs(
+    prefill_token_logprobs=None,
+    decode_token_logprobs=None,
+    prefill_top_logprobs=None,
+    decode_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not Supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if prefill_token_logprobs is not None:
+        append_token_logprobs(prefill_token_logprobs)
+    if decode_token_logprobs is not None:
+        append_token_logprobs(decode_token_logprobs)
+    if prefill_top_logprobs is not None:
+        append_top_logprobs(prefill_top_logprobs)
+    if decode_top_logprobs is not None:
+        append_top_logprobs(decode_top_logprobs)
+
+    return ret_logprobs
+
+
+def load_chat_template_for_openai_api(chat_template_arg):
+    global chat_template_name
+
+    print(f"Use chat template: {chat_template_arg}")
+    if not chat_template_exists(chat_template_arg):
+        if not os.path.exists(chat_template_arg):
+            raise RuntimeError(
+                f"Chat template {chat_template_arg} is not a built-in template name "
+                "or a valid chat template file path."
+            )
+        with open(chat_template_arg, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
@@ -560,7 +507,35 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                 )
             chat_template_name = template["name"]
     else:
-        chat_template_name = server_args.chat_template
+        chat_template_name = chat_template_arg
+
+
+def launch_server(server_args: ServerArgs, pipe_finish_writer):
+    global tokenizer_manager
+
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+    if server_args.disable_disk_cache:
+        disable_cache()
+    if server_args.enable_flashinfer:
+        assert_pkg_version("flashinfer", "0.0.4")
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Allocate ports
+    server_args.port, server_args.additional_ports = allocate_init_ports(
+        server_args.port, server_args.additional_ports, server_args.tp_size
+    )
+
+    port_args = PortArgs(
+        tokenizer_port=server_args.additional_ports[0],
+        router_port=server_args.additional_ports[1],
+        detokenizer_port=server_args.additional_ports[2],
+        nccl_port=server_args.additional_ports[3],
+        model_rpc_ports=server_args.additional_ports[4:],
+    )
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args)
@@ -593,31 +568,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     if router_init_state != "init ok" or detoken_init_state != "init ok":
         proc_router.kill()
         proc_detoken.kill()
-        print("router init state:", router_init_state)
-        print("detoken init state:", detoken_init_state)
+        print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
+        print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
         sys.exit(1)
 
     assert proc_router.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
-    def _launch_server():
-        uvicorn.run(
-            app,
-            host=server_args.host,
-            port=server_args.port,
-            log_level=server_args.log_level,
-            timeout_keep_alive=5,
-            loop="uvloop",
-        )
-
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
-        if server_args.api_key and server_args.api_key != "":
+        if server_args.api_key:
             headers[API_KEY_HEADER_NAME] = server_args.api_key
 
+        # Wait until the server is launched
         for _ in range(120):
             time.sleep(0.5)
             try:
@@ -625,16 +590,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                 break
             except requests.exceptions.RequestException as e:
                 pass
-        else:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(str(e))
-            else:
-                print(e, flush=True)
-            return
 
-        # Warmup
+        # Send a warmup request
         try:
-            # print("Warmup...", flush=True)
             res = requests.post(
                 url + "/generate",
                 json={
@@ -647,14 +605,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                 headers=headers,
                 timeout=60,
             )
-            # print(f"Warmup done. model response: {res.json()['text']}")
-            # print("=" * 20, "Server is ready", "=" * 20, flush=True)
-        except requests.exceptions.RequestException as e:
+            assert res.status_code == 200
+        except Exception as e:
             if pipe_finish_writer is not None:
-                pipe_finish_writer.send(str(e))
-            else:
-                print(e, flush=True)
-            return
+                pipe_finish_writer.send(get_exception_traceback())
+            print(f"Initialization failed. warmup error: {e}")
+            raise e
 
         if pipe_finish_writer is not None:
             pipe_finish_writer.send("init ok")
@@ -662,7 +618,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     t = threading.Thread(target=_wait_and_warmup)
     t.start()
     try:
-        _launch_server()
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
     finally:
         t.join()
@@ -670,52 +633,16 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
 class Runtime:
     def __init__(
         self,
-        model_path: str,
-        tokenizer_path: Optional[str] = None,
-        load_format: str = "auto",
-        tokenizer_mode: str = "auto",
-        trust_remote_code: bool = True,
-        mem_fraction_static: float = ServerArgs.mem_fraction_static,
-        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
-        context_length: int = ServerArgs.context_length,
-        tp_size: int = 1,
-        schedule_heuristic: str = "lpm",
-        attention_reduce_in_fp32: bool = False,
-        random_seed: int = 42,
-        log_level: str = "error",
-        disable_radix_cache: bool = False,
-        enable_flashinfer: bool = False,
-        disable_regex_jump_forward: bool = False,
-        disable_disk_cache: bool = False,
-        api_key: str = "",
-        port: Optional[int] = None,
-        additional_ports: Optional[Union[List[int], int]] = None,
+        log_evel="error",
+        *args,
+        **kwargs,
     ):
-        host = "127.0.0.1"
-        port, additional_ports = handle_port_init(port, additional_ports, tp_size)
-        self.server_args = ServerArgs(
-            model_path=model_path,
-            tokenizer_path=tokenizer_path,
-            host=host,
-            port=port,
-            additional_ports=additional_ports,
-            load_format=load_format,
-            tokenizer_mode=tokenizer_mode,
-            trust_remote_code=trust_remote_code,
-            mem_fraction_static=mem_fraction_static,
-            max_prefill_num_token=max_prefill_num_token,
-            context_length=context_length,
-            tp_size=tp_size,
-            schedule_heuristic=schedule_heuristic,
-            attention_reduce_in_fp32=attention_reduce_in_fp32,
-            random_seed=random_seed,
-            log_level=log_level,
-            disable_radix_cache=disable_radix_cache,
-            enable_flashinfer=enable_flashinfer,
-            disable_regex_jump_forward=disable_regex_jump_forward,
-            disable_disk_cache=disable_disk_cache,
-            api_key=api_key,
-        )
+        """See the arguments in server_args.py::ServerArgs"""
+        self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
+
+        # Pre-allocate ports
+        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
+            self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
+
         self.url = self.server_args.url()
         self.generate_url = (
@@ -736,7 +663,7 @@ class Runtime:
 
         if init_state != "init ok":
             self.shutdown()
-            raise RuntimeError("Launch failed. Please see the error messages above.")
+            raise RuntimeError("Initialization failed. Please see the error messages above.")
 
         self.endpoint = RuntimeEndpoint(self.url)
 
@@ -765,13 +692,12 @@ class Runtime:
         self,
         prompt: str,
         sampling_params,
-    ) -> None:
+    ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "stream": True,
         }
         pos = 0
 
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
...
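Since Runtime now forwards *args/**kwargs directly to ServerArgs, any field defined in server_args.py can be passed as a keyword argument. A hedged usage sketch (the model path is a placeholder, not part of the commit):

    # Hypothetical usage of the refactored Runtime wrapper; keyword arguments are
    # forwarded verbatim to ServerArgs and the log level defaults to "error".
    from sglang.srt.server import Runtime

    runtime = Runtime(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
        tp_size=1,
    )
    print(runtime.url)  # e.g. http://127.0.0.1:30000
    # runtime.endpoint is a RuntimeEndpoint usable as an sglang backend
    runtime.shutdown()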
"""The arguments of the server."""
import argparse import argparse
import dataclasses import dataclasses
from typing import List, Optional, Union from typing import List, Optional, Union
...@@ -5,33 +7,44 @@ from typing import List, Optional, Union ...@@ -5,33 +7,44 @@ from typing import List, Optional, Union
@dataclasses.dataclass @dataclasses.dataclass
class ServerArgs: class ServerArgs:
# Model and tokenizer
model_path: str model_path: str
tokenizer_path: Optional[str] = None tokenizer_path: Optional[str] = None
host: str = "127.0.0.1"
port: int = 30000
additional_ports: Optional[Union[List[int], int]] = None
load_format: str = "auto" load_format: str = "auto"
tokenizer_mode: str = "auto" tokenizer_mode: str = "auto"
chat_template: Optional[str] = None chat_template: Optional[str] = None
trust_remote_code: bool = True trust_remote_code: bool = True
context_length: Optional[int] = None
# Port
host: str = "127.0.0.1"
port: int = 30000
additional_ports: Optional[Union[List[int], int]] = None
# Memory and scheduling
mem_fraction_static: Optional[float] = None mem_fraction_static: Optional[float] = None
max_prefill_num_token: Optional[int] = None max_prefill_num_token: Optional[int] = None
context_length: Optional[int] = None
tp_size: int = 1
schedule_heuristic: str = "lpm" schedule_heuristic: str = "lpm"
schedule_conservativeness: float = 1.0 schedule_conservativeness: float = 1.0
attention_reduce_in_fp32: bool = False
random_seed: int = 42 # Other runtime options
tp_size: int = 1
stream_interval: int = 8 stream_interval: int = 8
random_seed: int = 42
# Logging
log_level: str = "info"
disable_log_stats: bool = False disable_log_stats: bool = False
log_stats_interval: int = 10 log_stats_interval: int = 10
log_level: str = "info"
api_key: str = ""
show_time_cost: bool = False show_time_cost: bool = False
# optional modes # Other
disable_radix_cache: bool = False api_key: str = ""
# Optimization/debug options
enable_flashinfer: bool = False enable_flashinfer: bool = False
attention_reduce_in_fp32: bool = False
disable_radix_cache: bool = False
disable_regex_jump_forward: bool = False disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False disable_disk_cache: bool = False
@@ -66,15 +79,16 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host)
-        parser.add_argument("--port", type=int, default=ServerArgs.port)
-        # we want to be able to pass a list of ports
+        parser.add_argument("--host", type=str, default=ServerArgs.host,
+                            help="The host of the server.")
+        parser.add_argument("--port", type=int, default=ServerArgs.port,
+                            help="The port of the server.")
         parser.add_argument(
             "--additional-ports",
             type=int,
             nargs="*",
             default=[],
-            help="Additional ports specified for launching server.",
+            help="Additional ports specified for the server.",
         )
         parser.add_argument(
             "--load-format",
@@ -112,6 +126,12 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -124,18 +144,6 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
-        )
-        parser.add_argument(
-            "--tp-size",
-            type=int,
-            default=ServerArgs.tp_size,
-            help="Tensor parallelism degree.",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -149,15 +157,10 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
-            "--random-seed",
+            "--tp-size",
             type=int,
-            default=ServerArgs.random_seed,
-            help="Random seed.",
-        )
-        parser.add_argument(
-            "--attention-reduce-in-fp32",
-            action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+            default=ServerArgs.tp_size,
+            help="Tensor parallelism size.",
         )
         parser.add_argument(
             "--stream-interval",
@@ -165,11 +168,17 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--random-seed",
+            type=int,
+            default=ServerArgs.random_seed,
+            help="Random seed.",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
             default=ServerArgs.log_level,
-            help="Log level",
+            help="Logging level",
         )
         parser.add_argument(
             "--disable-log-stats",
@@ -182,28 +191,33 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        parser.add_argument(
+            "--show-time-cost",
+            action="store_true",
+            help="Show time cost of custom marks",
+        )
         parser.add_argument(
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API Key",
+            help="Set API key of the server",
         )
+
+        # Optimization/debug options
         parser.add_argument(
-            "--show-time-cost",
+            "--enable-flashinfer",
             action="store_true",
-            help="Show time cost of custom marks",
+            help="Enable flashinfer inference kernels",
         )
-
-        # optional modes
         parser.add_argument(
-            "--disable-radix-cache",
+            "--attention-reduce-in-fp32",
             action="store_true",
-            help="Disable RadixAttention",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
         )
         parser.add_argument(
-            "--enable-flashinfer",
+            "--disable-radix-cache",
             action="store_true",
-            help="Enable flashinfer inference kernels",
+            help="Disable RadixAttention",
         )
         parser.add_argument(
             "--disable-regex-jump-forward",
@@ -224,13 +238,13 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
-    def get_optional_modes_logging(self):
+    def print_mode_args(self):
         return (
-            f"disable_radix_cache={self.disable_radix_cache}, "
             f"enable_flashinfer={self.enable_flashinfer}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+            f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
             f"disable_disk_cache={self.disable_disk_cache}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
         )
...
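The reordered ServerArgs fields and the renamed print_mode_args() helper can be exercised directly; a short sketch (the model path is a placeholder, not part of the commit):

    # Hypothetical sketch: register the CLI flags on a parser and inspect the
    # optimization/debug flags via the renamed print_mode_args() helper.
    import argparse
    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)  # same flags as in the diff above

    args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf")  # placeholder
    print(args.url())              # http://127.0.0.1:30000
    print(args.print_mode_args())  # enable_flashinfer=False, attention_reduce_in_fp32=False...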
@@ -10,9 +10,12 @@ from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
+import pydantic
 import requests
 import torch
 from packaging import version as pkg_version
+from pydantic import BaseModel
+from starlette.middleware.base import BaseHTTPMiddleware
 
 show_time_cost = False
 time_infos = {}
@@ -120,7 +123,7 @@ def check_port(port):
         return False
 
 
-def handle_port_init(
+def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
@@ -159,8 +162,6 @@ def get_exception_traceback():
 
 def get_int_token_logit_bias(tokenizer, vocab_size):
-    from transformers import LlamaTokenizer, LlamaTokenizerFast
-
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -281,3 +282,32 @@ def assert_pkg_version(pkg: str, min_version: str):
         )
     except PackageNotFoundError:
         raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
+
+
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
+
+# FIXME: Remove this once we drop support for pydantic 1.x
+IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
+
+
+def jsonify_pydantic_model(obj: BaseModel):
+    if IS_PYDANTIC_1:
+        return obj.json(ensure_ascii=False)
+    return obj.model_dump_json()
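With the API-key middleware and pydantic helpers now living in utils, a server launched with --api-key rejects requests that lack the X-API-Key header with a 403 JSON body. A hedged client-side sketch (key, host, and payload are placeholders, not part of the commit):

    # Hypothetical client sketch: authenticate against a server started with --api-key.
    import requests

    resp = requests.post(
        "http://127.0.0.1:30000/generate",  # placeholder host/port
        json={"text": "Hello", "sampling_params": {"max_new_tokens": 8}},
        headers={"X-API-Key": "my-secret-key"},  # placeholder key
    )
    if resp.status_code == 403:
        print(resp.json())  # {"detail": "Invalid API Key"}

jsonify_pydantic_model keeps response serialization uniform across pydantic versions: it calls .json(ensure_ascii=False) on pydantic 1.x and .model_dump_json() on 2.x.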
-kill -9 $(ps aux | grep 'python' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'sglang' | grep -v 'grep' | awk '{print $2}')