Unverified commit a5095520, authored by Lianmin Zheng and committed by GitHub

[minor] Improve code style and compatibility (#1961)

parent 7ef0084b
@@ -21,6 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
     "torchao", "uvicorn", "uvloop", "zmq",
     "outlines>=0.0.44", "modelscope"]
 srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
 srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
...
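Since the srt and srt_hip extras pin different vllm builds, a quick illustrative check (not part of this commit) is to query package metadata and see whether the CUDA pin (0.6.3.post1) or the ROCm dev build (0.6.3.dev13) is the one actually installed:

# Illustrative sketch: report installed versions of the packages pinned above.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("sglang", "vllm", "torch"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg} is not installed")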
@@ -461,7 +461,7 @@ class TokenizerManager:
                 break

         kill_child_process(include_self=True)
-        sys.exit(-1)
+        sys.exit(0)

     async def handle_loop(self):
         """The event loop that handles requests"""
...
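The graceful-shutdown path now exits with status 0 instead of -1, so a supervising process can distinguish a clean drain from a crash. A minimal sketch of a wrapper that reads that status; the entry-point script name is hypothetical and only for illustration:

# Minimal sketch (not from the commit): restart the server only on a non-zero exit.
import subprocess
import sys

proc = subprocess.run([sys.executable, "run_sglang_server.py"])  # hypothetical entry point
if proc.returncode == 0:
    print("server exited cleanly (graceful shutdown), not restarting")
else:
    print(f"server exited with code {proc.returncode}, restarting")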
@@ -32,7 +32,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessorOutput,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import monkey_patch_vllm_all_gather
+from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -92,7 +92,7 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024


-@torch.compile(dynamic=True)
+@maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
...
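The new maybe_torch_compile decorator (added to sglang.srt.utils later in this diff) compiles the wrapped function only when Triton 3 is available and otherwise returns it unchanged. A standalone sketch of the same pattern, gated on a plain boolean so it runs without sglang or Triton installed:

# Standalone sketch of the maybe-compile pattern used above; it mirrors the
# helper added to sglang.srt.utils but uses a simple boolean instead of the
# Triton-3 version check.
import torch

ENABLE_COMPILE = hasattr(torch, "compile")  # stand-in for the is_triton_3() check

def maybe_torch_compile(*args, **kwargs):
    def decorator(func):
        if ENABLE_COMPILE:
            return torch.compile(*args, **kwargs)(func)
        return func
    return decorator

@maybe_torch_compile(dynamic=True)
def clamp_position(seq_lens):
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

print(clamp_position(torch.tensor([0, 1, 5])))  # tensor([0, 0, 4])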
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     assert_pkg_version,
     configure_logger,
+    delete_directory,
     is_port_available,
     kill_child_process,
     maybe_set_triton_cache_manager,
@@ -97,8 +98,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

 app = FastAPI()

-tokenizer_manager: TokenizerManager = None
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -107,6 +106,10 @@ app.add_middleware(
     allow_headers=["*"],
 )

+tokenizer_manager: TokenizerManager = None
+
+
+##### Native API endpoints #####
 @app.get("/health")
 async def health() -> Response:
@@ -275,6 +278,9 @@ app.post("/classify")(classify_request)
 app.put("/classify")(classify_request)

+
+##### OpenAI-compatible API endpoints #####
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
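With the endpoints now grouped into native and OpenAI-compatible sections, a quick smoke test against a running server might look like the following sketch; the host, port, and model name are assumptions to adjust for your deployment:

# Illustrative client-side smoke test (assumes a server on localhost:30000).
import requests

base = "http://localhost:30000"

# Native endpoint: liveness check.
print(requests.get(f"{base}/health").status_code)

# OpenAI-compatible endpoint: a tiny completion request.
resp = requests.post(
    f"{base}/v1/completions",
    json={"model": "default", "prompt": "The capital of France is", "max_tokens": 8},
)
print(resp.json())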
@@ -420,19 +426,6 @@ def launch_engine(
         scheduler_pipe_readers[i].recv()


-def add_prometheus_middleware(app: FastAPI):
-    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
-    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
-
-    registry = CollectorRegistry()
-    multiprocess.MultiProcessCollector(registry)
-    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
-
-    # Workaround for 307 Redirect for /metrics
-    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
-    app.routes.append(metrics_route)
-
-
 def launch_server(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
@@ -492,6 +485,19 @@ def launch_server(
         t.join()


+def add_prometheus_middleware(app: FastAPI):
+    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
+    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
+
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
+
+
 def _set_prometheus_env():
     # Set prometheus multiprocess directory
     # sglang uses prometheus multiprocess mode
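add_prometheus_middleware (moved below launch_server rather than changed) mounts a multiprocess-mode metrics app at /metrics; the path_regex tweak avoids the 307 redirect FastAPI would otherwise issue for the bare path. When the server runs with --enable-metrics, the endpoint can be scraped like any Prometheus target; a hedged example with the host and port assumed:

# Illustrative scrape of the mounted /metrics route
# (assumes a server started with --enable-metrics on localhost:30000).
import requests

resp = requests.get("http://localhost:30000/metrics")
resp.raise_for_status()
# Print only the metric names, one per line.
for line in resp.text.splitlines():
    if line and not line.startswith("#"):
        print(line.split("{")[0].split(" ")[0])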
@@ -565,6 +571,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             return

     model_info = res.json()

     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -602,6 +609,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")

+    if server_args.delete_ckpt_after_loading:
+        delete_directory(server_args.model_path)
+

 class Runtime:
     """
...
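The warmup step picks /generate for generative models and /encode for embedding models, and this commit additionally deletes the checkpoint directory after loading when --delete-ckpt-after-loading is set. A hedged sketch of what a warmup-style request to the native /generate endpoint looks like; the server address and prompt are assumptions, while max_new_tokens=8 matches the generative-model branch above:

# Illustrative warmup-style request against the native /generate endpoint.
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Hello, my name is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    },
    timeout=600,
)
print(resp.json())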
@@ -63,7 +63,7 @@ class ServerArgs:
     stream_interval: int = 1
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
-    decode_log_interval: int = 40
+    watchdog_timeout: float = 300

     # Logging
     log_level: str = "info"
@@ -71,18 +71,18 @@ class ServerArgs:
     log_requests: bool = False
     show_time_cost: bool = False
     enable_metrics: bool = False
+    decode_log_interval: int = 40

-    # Other
+    # API related
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
     enable_cache_report: bool = False
-    watchdog_timeout: float = 600

     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"

-    # Distributed args
+    # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
     nnodes: int = 1
     node_rank: int = 0
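After the reshuffle, watchdog_timeout (now defaulting to 300 seconds instead of 600) sits with the core runtime options, decode_log_interval moves under logging, and the new delete_ckpt_after_loading flag joins the optimization/debug toggles. Constructing the dataclass directly shows the defaults; this sketch assumes sglang is installed and the model path is only a placeholder:

# Illustrative: inspect the reorganized defaults by constructing ServerArgs directly.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder path
print(args.watchdog_timeout)           # 300
print(args.decode_log_interval)        # 40
print(args.delete_ckpt_after_loading)  # False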
@@ -128,6 +128,7 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
+    delete_ckpt_after_loading: bool = False

     def __post_init__(self):
         # Set missing default values
@@ -205,6 +206,7 @@ class ServerArgs:
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        # Model and port args
         parser.add_argument(
             "--model-path",
             type=str,
@@ -324,6 +326,8 @@ class ServerArgs:
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+
+        # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -368,6 +372,8 @@ class ServerArgs:
             default=ServerArgs.schedule_conservativeness,
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
+
+        # Other runtime options
        parser.add_argument(
             "--tensor-parallel-size",
             "--tp-size",
@@ -393,6 +399,14 @@ class ServerArgs:
             default=ServerArgs.constrained_json_whitespace_pattern,
             help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
         )
+        parser.add_argument(
+            "--watchdog-timeout",
+            type=float,
+            default=ServerArgs.watchdog_timeout,
+            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
+        )
+
+        # Logging
         parser.add_argument(
             "--log-level",
             type=str,
@@ -420,7 +434,14 @@ class ServerArgs:
             action="store_true",
             help="Enable log prometheus metrics.",
         )
+        parser.add_argument(
+            "--decode-log-interval",
+            type=int,
+            default=ServerArgs.decode_log_interval,
+            help="The log interval of decode batch",
+        )

+        # API related
         parser.add_argument(
             "--api-key",
             type=str,
@@ -438,18 +459,6 @@ class ServerArgs:
             action="store_true",
             help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
         )
-        parser.add_argument(
-            "--watchdog-timeout",
-            type=float,
-            default=ServerArgs.watchdog_timeout,
-            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
-        )
-        parser.add_argument(
-            "--decode-log-interval",
-            type=int,
-            default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
-        )

         # Data parallelism
         parser.add_argument(
@@ -470,7 +479,7 @@ class ServerArgs:
             ],
         )

-        # Multi-node distributed serving args
+        # Multi-node distributed serving
         parser.add_argument(
             "--dist-init-addr",
             "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
@@ -677,6 +686,12 @@ class ServerArgs:
             "This can potentially increase throughput but may also increase time-to-first-token latency. "
             "The default value is 1, meaning only run one decoding step at a time.",
         )
+        parser.add_argument(
+            "--delete-ckpt-after-loading",
+            default=ServerArgs.delete_ckpt_after_loading,
+            action="store_true",
+            help="Delete the model checkpoint after loading the model.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
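The CLI flags are now grouped by section (model/port, memory and scheduling, runtime, logging, API, parallelism, optimization), and the relocated flags parse exactly as before. A minimal sketch that mirrors only the flags touched by this commit, to show the expected types and defaults; this is a standalone parser, not sglang's actual one:

# Standalone sketch mirroring the flags touched in this commit
# (defaults match the reorganized ServerArgs fields).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--watchdog-timeout", type=float, default=300)
parser.add_argument("--decode-log-interval", type=int, default=40)
parser.add_argument("--delete-ckpt-after-loading", action="store_true")

args = parser.parse_args(
    ["--watchdog-timeout", "120", "--decode-log-interval", "10", "--delete-ckpt-after-loading"]
)
print(args)
# Namespace(watchdog_timeout=120.0, decode_log_interval=10, delete_ckpt_after_loading=True)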
@@ -23,6 +23,8 @@ import os
 import pickle
 import random
 import resource
+import shutil
+import signal
 import socket
 import time
 import warnings
@@ -35,6 +37,7 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
+import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
@@ -379,6 +382,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
     if include_self:
         try:
             itself.kill()
+
+            # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
+            # so we send an additional signal to kill them.
+            itself.send_signal(signal.SIGINT)
         except psutil.NoSuchProcess:
             pass
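The extra SIGINT matters mainly when the process runs as PID 1 inside a container: signals that PID 1 has no handler for are not delivered from within its own PID namespace, and SIGKILL cannot be handled at all, so the follow-up SIGINT is what actually reaches a Python process (which installs a default SIGINT handler). A hedged companion sketch of an entry point that makes sure SIGINT leads to a clean exit:

# Illustrative companion to the change above: a long-running entry point that
# exits cleanly on SIGINT, so kill_child_process's SIGINT fallback works even
# when this process is PID 1 in a container.
import signal
import sys
import time

def handle_sigint(signum, frame):
    print("received SIGINT, shutting down")
    sys.exit(0)

signal.signal(signal.SIGINT, handle_sigint)

while True:
    time.sleep(1)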
@@ -704,3 +711,44 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
         raise ValueError(f"Unsupported socket type: {socket_type}")

     return socket
+
+
+def dump_to_file(dirpath, name, value):
+    from vllm.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() != 0:
+        return
+
+    os.makedirs(dirpath, exist_ok=True)
+    if value.dtype is torch.bfloat16:
+        value = value.float()
+    value = value.cpu().numpy()
+    output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
+    logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
+    np.save(output_filename, value)
+
+
+def is_triton_3():
+    return triton.__version__.startswith("3.")
+
+
+def maybe_torch_compile(*args, **kwargs):
+    """
+    torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
+    Therefore, we disable it here.
+    """
+
+    def decorator(func):
+        if is_triton_3():
+            return torch.compile(*args, **kwargs)(func)
+        return func
+
+    return decorator
+
+
+def delete_directory(dirpath):
+    try:
+        # This will remove the directory and all its contents
+        shutil.rmtree(dirpath)
+    except OSError as e:
+        print(f"Warning: {dirpath} : {e.strerror}")