Unverified Commit 0463f7fb authored by Ying Sheng, committed by GitHub

Support data parallelism (static) (#480)


Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 565d7274
......@@ -35,7 +35,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
......
......@@ -30,7 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
class MixtralMLP(nn.Module):
......
......@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
class QWenMLP(nn.Module):
......
......@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
Qwen2Config = None
......
......@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
class StablelmMLP(nn.Module):
......
......@@ -10,7 +10,7 @@ import sys
import threading
import time
from http import HTTPStatus
-from typing import List, Optional, Union
+from typing import Optional
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -28,14 +28,15 @@ from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.router.manager import start_router_process
+from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
+from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import (
load_chat_template_for_openai_api,
v1_chat_completions,
v1_completions,
)
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
from sglang.srt.utils import (
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware,
......@@ -141,14 +142,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port, server_args.additional_ports, server_args.tp_size
+        server_args.port,
+        server_args.additional_ports,
+        server_args.tp_size,
+        server_args.dp_size,
     )
+    # Init local models port args
+    ports = server_args.additional_ports
+    tp = server_args.tp_size
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp + 1)],
+                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+            )
+        )
     port_args = PortArgs(
-        tokenizer_port=server_args.additional_ports[0],
-        router_port=server_args.additional_ports[1],
-        detokenizer_port=server_args.additional_ports[2],
-        nccl_port=server_args.additional_ports[3],
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
     )
     # Launch processes
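
Not part of the diff: a small worked example of the port layout implied by the indexing above, under the assumption tp_size=2 and dp_size=2. After the three fixed ports (tokenizer, controller/router, detokenizer), each data-parallel replica gets one NCCL port followed by tp_size model ports.

```python
# Illustrative arithmetic only (hypothetical port values, not from the commit).
tp, dp = 2, 2
ports = list(range(30000, 30000 + 3 + dp * (tp + 1)))  # 9 additional ports

tokenizer_port, controller_port, detokenizer_port = ports[0], ports[1], ports[2]
for i in range(dp):
    nccl_port = ports[3 + i * (tp + 1)]
    model_tp_ports = ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)]
    print(f"replica {i}: nccl={nccl_port}, tp_ports={model_tp_ports}")
# replica 0: nccl=30003, tp_ports=[30004, 30005]
# replica 1: nccl=30006, tp_ports=[30007, 30008]
```
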
......@@ -156,8 +171,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
     proc_router = mp.Process(
-        target=start_router_process,
+        target=start_process,
         args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
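
When dp_size > 1 the controller process comes from manager_multi, which spreads requests over the data-parallel workers according to --load-balance-method. The sketch below is not the manager_multi implementation (this diff does not show it); it is only a minimal, hypothetical illustration of the default round_robin policy, with invented names (RoundRobinDispatcher, workers, handle).

```python
from itertools import count


class RoundRobinDispatcher:
    """Minimal round-robin sketch: send each request to the next replica in cyclic order."""

    def __init__(self, workers):
        self.workers = workers  # one handle per data-parallel replica
        self._counter = count()

    def dispatch(self, request):
        worker = self.workers[next(self._counter) % len(self.workers)]
        return worker.handle(request)
```
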
......@@ -251,19 +270,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 class Runtime:
     def __init__(
         self,
-        log_evel: str = "error",
+        log_level: str = "error",
         model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
         """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
         # Pre-allocate ports
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
             self.server_args.tp_size,
+            self.server_args.dp_size,
         )
         self.url = self.server_args.url()
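
With dp_size threaded into allocate_init_ports, the Runtime constructor reserves enough ports for every replica up front. A hedged usage sketch; the import path and model name are assumptions for illustration, not taken from this diff.

```python
from sglang.srt.server import Runtime  # assuming the class defined in this file

# Hypothetical sizing: 2 replicas x 2 tensor-parallel GPUs = 4 GPUs total.
runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    tp_size=2,
    dp_size=2,
)
print(runtime.url)  # the HTTP endpoint chosen by allocate_init_ports
```
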
......
......@@ -44,6 +44,10 @@ class ServerArgs:
     # Other
     api_key: str = ""
+    # Data parallelism
+    dp_size: int = 1
+    load_balance_method: str = "round_robin"
     # Optimization/debug options
     enable_flashinfer: bool = False
     attention_reduce_in_fp32: bool = False
......@@ -226,6 +230,24 @@ class ServerArgs:
help="Set API key of the server",
)
# Data parallelism
parser.add_argument(
"--dp-size",
type=int,
default=ServerArgs.dp_size,
help="Data parallelism size.",
)
parser.add_argument(
"--load-balance-method",
type=str,
default=ServerArgs.load_balance_method,
help="Load balancing strategy for data parallelism.",
choices=[
"round_robin",
"shortest_queue",
],
)
# Optimization/debug options
parser.add_argument(
"--enable-flashinfer",
......@@ -271,10 +293,15 @@ class ServerArgs:
)
+@dataclasses.dataclass
+class ModelPortArgs:
+    nccl_port: int
+    model_tp_ports: List[int]
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
     router_port: int
     detokenizer_port: int
-    nccl_port: int
-    model_rpc_ports: List[int]
+    model_port_args: List[ModelPortArgs]
"""Common utilities."""
import base64
import multiprocessing
import logging
import os
import random
......@@ -12,12 +13,14 @@ from typing import List, Optional
import numpy as np
import requests
import rpyc
import torch
import triton
from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
import torch.distributed as dist
logger = logging.getLogger(__name__)
......@@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False):
def set_random_seed(seed: int) -> None:
"""Set the random seed for all libraries."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def is_port_available(port):
"""Return whether a port is available."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
......@@ -142,7 +147,9 @@ def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
+    dp_size: int = 1,
 ):
     """Allocate ports for all connections."""
     if additional_ports:
         ret_ports = [port] + additional_ports
     else:
......@@ -151,20 +158,23 @@ def allocate_init_ports(
     ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
     cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
-    while len(ret_ports) < 5 + tp_size:
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    while len(ret_ports) < num_ports_needed:
         if cur_port not in ret_ports and is_port_available(cur_port):
             ret_ports.append(cur_port)
         cur_port += 1
-    if port and ret_ports[0] != port:
+    if port is not None and ret_ports[0] != port:
         logger.warn(
             f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
-    return ret_ports[0], ret_ports[1:]
+    return ret_ports[0], ret_ports[1:num_ports_needed]


 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
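
As a quick sanity check on the new bound (again, not part of the commit): with dp_size = 1 the formula 4 + dp_size * (1 + tp_size) reduces to 5 + tp_size, exactly the old requirement, and each extra replica costs one NCCL port plus tp_size model ports.

```python
def num_ports_needed(tp_size: int, dp_size: int) -> int:
    # HTTP + tokenizer + controller + detokenizer + dp_size * (nccl + tp_size)
    return 4 + dp_size * (1 + tp_size)

assert num_ports_needed(tp_size=2, dp_size=1) == 7   # same as the old 5 + tp_size
assert num_ports_needed(tp_size=2, dp_size=2) == 10  # +3 ports for the second replica
```
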
......@@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel):
     if int(triton.__version__.split(".")[0]) >= 3:
         return None
-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = 0
-    kernels = kernel.cache[rank].values()
+    gpu_id = torch.cuda.current_device()
+    kernels = kernel.cache[gpu_id].values()
     kernel = next(iter(kernels))
     # Different Triton versions use different low-level names
......@@ -363,6 +369,63 @@ def load_image(image_file):
     return image, image_size
+def init_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def connect_to_rpyc_service(port, host="localhost"):
+    time.sleep(1)
+    repeat_count = 0
+    while repeat_count < 20:
+        try:
+            con = rpyc.connect(
+                host,
+                port,
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600
+                },
+            )
+            break
+        except ConnectionRefusedError:
+            time.sleep(1)
+        repeat_count += 1
+    if repeat_count == 20:
+        raise RuntimeError("init rpc env error!")
+    return con.root
+
+
+def start_rpyc_process(service: rpyc.Service, port: int):
+    # Return the proxy and the process
+    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+    proc.start()
+    proxy = connect_to_rpyc_service(port)
+    assert proc.is_alive()
+    return proxy, proc
+
+
+def suppress_other_loggers():
+    from vllm.logger import logger as vllm_default_logger
+
+    vllm_default_logger.setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+    logging.getLogger("vllm.selector").setLevel(logging.WARN)
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)
 def assert_pkg_version(pkg: str, min_version: str):
     try:
         installed_version = version(pkg)
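
The three rpyc helpers above give the controller a uniform way to run each model worker in its own process and talk to it through a proxy. Below is a minimal, hypothetical usage sketch; EchoService and the port number are invented for the example, and the real worker services live in the controller modules that this hunk does not show.

```python
import rpyc


class EchoService(rpyc.Service):
    # Toy service standing in for a real model worker service.
    def exposed_echo(self, msg):
        return msg


if __name__ == "__main__":
    # Forks a child running init_rpyc_service on port 18861, then connects to it.
    proxy, proc = start_rpyc_process(EchoService(), port=18861)
    print(proxy.echo("hello"))  # executed inside the child process
    proc.terminate()
```
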
......@@ -394,4 +457,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
                 content={"detail": "Invalid API Key"},
             )
         response = await call_next(request)
-        return response
\ No newline at end of file
+        return response
......@@ -5,7 +5,7 @@ from dataclasses import dataclass
import torch
import torch.distributed as dist
-from sglang.srt.managers.router.model_runner import ModelRunner
+from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
......
......@@ -7,8 +7,8 @@ import torch
import torch.distributed as dist
import transformers
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
-from sglang.srt.managers.router.model_runner import ModelRunner
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
......
......@@ -5,7 +5,7 @@ import numpy as np
import torch
import torch.distributed as dist
-from sglang.srt.managers.router.model_runner import ModelRunner
+from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
......
......@@ -6,8 +6,8 @@ import torch
import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.managers.router.infer_batch import ForwardMode
-from sglang.srt.managers.router.model_runner import InputMetadata, ModelRunner
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata, ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.utils import load_image
......