Unverified Commit a5095520 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[minor] Improve code style and compatibility (#1961)

parent 7ef0084b
......@@ -21,6 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
"torchao", "uvicorn", "uvloop", "zmq",
"outlines>=0.0.44", "modelscope"]
srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
......
......@@ -461,7 +461,7 @@ class TokenizerManager:
break
kill_child_process(include_self=True)
sys.exit(-1)
sys.exit(0)
async def handle_loop(self):
"""The event loop that handles requests"""
......
......@@ -32,7 +32,7 @@ from sglang.srt.layers.logits_processor import (
LogitsProcessorOutput,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import monkey_patch_vllm_all_gather
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -92,7 +92,7 @@ def set_torch_compile_config():
torch._dynamo.config.accumulated_cache_size_limit = 1024
@torch.compile(dynamic=True)
@maybe_torch_compile(dynamic=True)
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
......
......@@ -79,6 +79,7 @@ from sglang.srt.utils import (
add_api_key_middleware,
assert_pkg_version,
configure_logger,
delete_directory,
is_port_available,
kill_child_process,
maybe_set_triton_cache_manager,
......@@ -97,8 +98,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
app = FastAPI()
tokenizer_manager: TokenizerManager = None
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
......@@ -107,6 +106,10 @@ app.add_middleware(
allow_headers=["*"],
)
tokenizer_manager: TokenizerManager = None
##### Native API endpoints #####
@app.get("/health")
async def health() -> Response:
......@@ -275,6 +278,9 @@ app.post("/classify")(classify_request)
app.put("/classify")(classify_request)
##### OpenAI-compatible API endpoints #####
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(tokenizer_manager, raw_request)
......@@ -420,19 +426,6 @@ def launch_engine(
scheduler_pipe_readers[i].recv()
def add_prometheus_middleware(app: FastAPI):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
def launch_server(
server_args: ServerArgs,
pipe_finish_writer: Optional[mp.connection.Connection] = None,
......@@ -492,6 +485,19 @@ def launch_server(
t.join()
def add_prometheus_middleware(app: FastAPI):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
def _set_prometheus_env():
# Set prometheus multiprocess directory
# sglang uses prometheus multiprocess mode
......@@ -565,6 +571,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
return
model_info = res.json()
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
max_new_tokens = 8 if model_info["is_generation"] else 1
......@@ -602,6 +609,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")
if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)
class Runtime:
"""
......
......@@ -63,7 +63,7 @@ class ServerArgs:
stream_interval: int = 1
random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None
decode_log_interval: int = 40
watchdog_timeout: float = 300
# Logging
log_level: str = "info"
......@@ -71,18 +71,18 @@ class ServerArgs:
log_requests: bool = False
show_time_cost: bool = False
enable_metrics: bool = False
decode_log_interval: int = 40
# Other
# API related
api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage"
enable_cache_report: bool = False
watchdog_timeout: float = 600
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
# Distributed args
# Multi-node distributed serving
dist_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: int = 0
......@@ -128,6 +128,7 @@ class ServerArgs:
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False
def __post_init__(self):
# Set missing default values
......@@ -205,6 +206,7 @@ class ServerArgs:
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
parser.add_argument(
"--model-path",
type=str,
......@@ -324,6 +326,8 @@ class ServerArgs:
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
# Memory and scheduling
parser.add_argument(
"--mem-fraction-static",
type=float,
......@@ -368,6 +372,8 @@ class ServerArgs:
default=ServerArgs.schedule_conservativeness,
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
)
# Other runtime options
parser.add_argument(
"--tensor-parallel-size",
"--tp-size",
......@@ -393,6 +399,14 @@ class ServerArgs:
default=ServerArgs.constrained_json_whitespace_pattern,
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
)
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)
# Logging
parser.add_argument(
"--log-level",
type=str,
......@@ -420,7 +434,14 @@ class ServerArgs:
action="store_true",
help="Enable log prometheus metrics.",
)
parser.add_argument(
"--decode-log-interval",
type=int,
default=ServerArgs.decode_log_interval,
help="The log interval of decode batch",
)
# API related
parser.add_argument(
"--api-key",
type=str,
......@@ -438,18 +459,6 @@ class ServerArgs:
action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
)
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)
parser.add_argument(
"--decode-log-interval",
type=int,
default=ServerArgs.decode_log_interval,
help="The log interval of decode batch",
)
# Data parallelism
parser.add_argument(
......@@ -470,7 +479,7 @@ class ServerArgs:
],
)
# Multi-node distributed serving args
# Multi-node distributed serving
parser.add_argument(
"--dist-init-addr",
"--nccl-init-addr", # For backward compatbility. This will be removed in the future.
......@@ -677,6 +686,12 @@ class ServerArgs:
"This can potentially increase throughput but may also increase time-to-first-token latency. "
"The default value is 1, meaning only run one decoding step at a time.",
)
parser.add_argument(
"--delete-ckpt-after-loading",
default=ServerArgs.delete_ckpt_after_loading,
action="store_true",
help="Delete the model checkpoint after loading the model.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
......@@ -23,6 +23,8 @@ import os
import pickle
import random
import resource
import shutil
import signal
import socket
import time
import warnings
......@@ -35,6 +37,7 @@ import psutil
import requests
import torch
import torch.distributed as dist
import triton
import zmq
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
......@@ -379,6 +382,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
if include_self:
try:
itself.kill()
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
# so we send an additional signal to kill them.
itself.send_signal(signal.SIGINT)
except psutil.NoSuchProcess:
pass
......@@ -704,3 +711,44 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
raise ValueError(f"Unsupported socket type: {socket_type}")
return socket
def dump_to_file(dirpath, name, value):
from vllm.distributed import get_tensor_model_parallel_rank
if get_tensor_model_parallel_rank() != 0:
return
os.makedirs(dirpath, exist_ok=True)
if value.dtype is torch.bfloat16:
value = value.float()
value = value.cpu().numpy()
output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
np.save(output_filename, value)
def is_triton_3():
return triton.__version__.startswith("3.")
def maybe_torch_compile(*args, **kwargs):
"""
torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
Therefore, we disable it here.
"""
def decorator(func):
if is_triton_3():
return torch.compile(*args, **kwargs)(func)
return func
return decorator
def delete_directory(dirpath):
try:
# This will remove the directory and all its contents
shutil.rmtree(dirpath)
except OSError as e:
print(f"Warning: {dirpath} : {e.strerror}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment