Unverified commit 0463f7fb, authored by Ying Sheng and committed by GitHub

Support data parallelism (static) (#480)


Co-authored-by: Ying Sheng <ying.sheng@databricks.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 565d7274
@@ -35,7 +35,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
...
@@ -30,7 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
 class MixtralMLP(nn.Module):
...
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
 class QWenMLP(nn.Module):
...
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
 Qwen2Config = None
...
@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
 class StablelmMLP(nn.Module):
...
@@ -10,7 +10,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import List, Optional, Union
+from typing import Optional

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -28,14 +28,15 @@ from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.router.manager import start_router_process
+from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
+from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
     load_chat_template_for_openai_api,
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
@@ -141,14 +142,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port, server_args.additional_ports, server_args.tp_size
+        server_args.port,
+        server_args.additional_ports,
+        server_args.tp_size,
+        server_args.dp_size,
     )
+
+    # Init local models port args
+    ports = server_args.additional_ports
+    tp = server_args.tp_size
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp + 1)],
+                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+            )
+        )
+
     port_args = PortArgs(
-        tokenizer_port=server_args.additional_ports[0],
-        router_port=server_args.additional_ports[1],
-        detokenizer_port=server_args.additional_ports[2],
-        nccl_port=server_args.additional_ports[3],
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
     )

     # Launch processes
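For readers tracing the port bookkeeping in the hunk above, here is a small worked sketch of the new layout. It is illustrative only: the port numbers are made up, and `tp`, `dp`, and `ports` simply mirror the local variables in `launch_server`.

```python
# Worked example of the additional_ports layout for tp_size=2, dp_size=2.
# Total ports needed: 4 + dp_size * (1 + tp_size) = 10 (1 HTTP + 9 additional).
tp, dp = 2, 2
ports = list(range(30001, 30001 + 3 + dp * (tp + 1)))  # the 9 "additional" ports (hypothetical values)

tokenizer_port, router_port, detokenizer_port = ports[0], ports[1], ports[2]
for i in range(dp):
    nccl_port = ports[3 + i * (tp + 1)]
    model_tp_ports = ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)]
    print(f"replica {i}: nccl={nccl_port}, tp_ports={model_tp_ports}")
# replica 0: nccl=30004, tp_ports=[30005, 30006]
# replica 1: nccl=30007, tp_ports=[30008, 30009]
```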
@@ -156,8 +171,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
     proc_router = mp.Process(
-        target=start_router_process,
+        target=start_process,
         args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
@@ -251,19 +270,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 class Runtime:
     def __init__(
         self,
-        log_evel: str = "error",
+        log_level: str = "error",
         model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
         """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

         # Pre-allocate ports
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
             self.server_args.tp_size,
+            self.server_args.dp_size,
         )
         self.url = self.server_args.url()
...
@@ -44,6 +44,10 @@ class ServerArgs:
     # Other
     api_key: str = ""

+    # Data parallelism
+    dp_size: int = 1
+    load_balance_method: str = "round_robin"
+
     # Optimization/debug options
     enable_flashinfer: bool = False
     attention_reduce_in_fp32: bool = False
@@ -226,6 +230,24 @@
             help="Set API key of the server",
         )

+        # Data parallelism
+        parser.add_argument(
+            "--dp-size",
+            type=int,
+            default=ServerArgs.dp_size,
+            help="Data parallelism size.",
+        )
+        parser.add_argument(
+            "--load-balance-method",
+            type=str,
+            default=ServerArgs.load_balance_method,
+            help="Load balancing strategy for data parallelism.",
+            choices=[
+                "round_robin",
+                "shortest_queue",
+            ],
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
@@ -271,10 +293,15 @@
         )


+@dataclasses.dataclass
+class ModelPortArgs:
+    nccl_port: int
+    model_tp_ports: List[int]
+
+
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
     router_port: int
     detokenizer_port: int
-    nccl_port: int
-    model_rpc_ports: List[int]
+    model_port_args: List[ModelPortArgs]
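The new `--dp-size` and `--load-balance-method` flags map directly to the `ServerArgs` fields added above. A brief, hypothetical usage sketch (the model path is a placeholder; `Runtime` forwards its keyword arguments to `ServerArgs`, per the `Runtime.__init__` change in this commit):

```python
# Hypothetical example: two data-parallel replicas, each a 1-GPU worker,
# dispatched by the multi-controller. Roughly equivalent to launching the
# HTTP server with: --dp-size 2 --load-balance-method round_robin
import sglang as sgl

runtime = sgl.Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    tp_size=1,
    dp_size=2,
    load_balance_method="round_robin",
)
print(runtime.url)
runtime.shutdown()
```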
"""Common utilities.""" """Common utilities."""
import base64 import base64
import multiprocessing
import logging import logging
import os import os
import random import random
...@@ -12,12 +13,14 @@ from typing import List, Optional ...@@ -12,12 +13,14 @@ from typing import List, Optional
import numpy as np import numpy as np
import requests import requests
import rpyc
import torch import torch
import triton import triton
from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
import torch.distributed as dist
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False): ...@@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False):
def set_random_seed(seed: int) -> None: def set_random_seed(seed: int) -> None:
"""Set the random seed for all libraries."""
random.seed(seed) random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def is_port_available(port): def is_port_available(port):
"""Return whether a port is available."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try: try:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
...@@ -142,7 +147,9 @@ def allocate_init_ports( ...@@ -142,7 +147,9 @@ def allocate_init_ports(
port: Optional[int] = None, port: Optional[int] = None,
additional_ports: Optional[List[int]] = None, additional_ports: Optional[List[int]] = None,
tp_size: int = 1, tp_size: int = 1,
dp_size: int = 1,
): ):
"""Allocate ports for all connections."""
if additional_ports: if additional_ports:
ret_ports = [port] + additional_ports ret_ports = [port] + additional_ports
else: else:
...@@ -151,20 +158,23 @@ def allocate_init_ports( ...@@ -151,20 +158,23 @@ def allocate_init_ports(
ret_ports = list(set(x for x in ret_ports if is_port_available(x))) ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000 cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
while len(ret_ports) < 5 + tp_size: # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
num_ports_needed = 4 + dp_size * (1 + tp_size)
while len(ret_ports) < num_ports_needed:
if cur_port not in ret_ports and is_port_available(cur_port): if cur_port not in ret_ports and is_port_available(cur_port):
ret_ports.append(cur_port) ret_ports.append(cur_port)
cur_port += 1 cur_port += 1
if port and ret_ports[0] != port: if port is not None and ret_ports[0] != port:
logger.warn( logger.warn(
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead." f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
) )
return ret_ports[0], ret_ports[1:] return ret_ports[0], ret_ports[1:num_ports_needed]
def get_int_token_logit_bias(tokenizer, vocab_size): def get_int_token_logit_bias(tokenizer, vocab_size):
"""Get the logit bias for integer-only tokens."""
# a bug when model's vocab size > tokenizer.vocab_size # a bug when model's vocab size > tokenizer.vocab_size
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32) logit_bias = np.zeros(vocab_size, dtype=np.float32)
...@@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel): ...@@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel):
if int(triton.__version__.split(".")[0]) >= 3: if int(triton.__version__.split(".")[0]) >= 3:
return None return None
if dist.is_initialized(): gpu_id = torch.cuda.current_device()
rank = dist.get_rank() kernels = kernel.cache[gpu_id].values()
else:
rank = 0
kernels = kernel.cache[rank].values()
kernel = next(iter(kernels)) kernel = next(iter(kernels))
# Different trition versions use different low-level names # Different trition versions use different low-level names
...@@ -363,6 +369,63 @@ def load_image(image_file): ...@@ -363,6 +369,63 @@ def load_image(image_file):
return image, image_size return image, image_size
def init_rpyc_service(service: rpyc.Service, port: int):
t = ThreadedServer(
service=service,
port=port,
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600
},
)
t.logger.setLevel(logging.WARN)
t.start()
def connect_to_rpyc_service(port, host="localhost"):
time.sleep(1)
repeat_count = 0
while repeat_count < 20:
try:
con = rpyc.connect(
host,
port,
config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600
},
)
break
except ConnectionRefusedError:
time.sleep(1)
repeat_count += 1
if repeat_count == 20:
raise RuntimeError("init rpc env error!")
return con.root
def start_rpyc_process(service: rpyc.Service, port: int):
# Return the proxy and the process
proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
proc.start()
proxy = connect_to_rpyc_service(port)
assert proc.is_alive()
return proxy, proc
def suppress_other_loggers():
from vllm.logger import logger as vllm_default_logger
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.config").setLevel(logging.ERROR)
def assert_pkg_version(pkg: str, min_version: str): def assert_pkg_version(pkg: str, min_version: str):
try: try:
installed_version = version(pkg) installed_version = version(pkg)
......
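The rpyc helpers added to the utilities are generic plumbing for the controller/worker split. A minimal sketch of how they might be wired together, assuming they are importable from `sglang.srt.utils`; the `EchoService` class and the port number are invented for illustration and are not part of the commit:

```python
# Hypothetical illustration of start_rpyc_process (which combines
# init_rpyc_service and connect_to_rpyc_service): run a tiny rpyc service
# in a child process and call it through the returned proxy.
import rpyc

from sglang.srt.utils import start_rpyc_process


class EchoService(rpyc.Service):
    def exposed_echo(self, msg):
        return msg


if __name__ == "__main__":
    proxy, proc = start_rpyc_process(EchoService, port=18861)  # arbitrary free port
    print(proxy.echo("ping"))  # "ping", proxied into the child process
    proc.terminate()
```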
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 import torch
 import torch.distributed as dist
-from sglang.srt.managers.router.model_runner import ModelRunner
+from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.model_config import ModelConfig
...
@@ -7,8 +7,8 @@ import torch
 import torch.distributed as dist
 import transformers
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
-from sglang.srt.managers.router.model_runner import ModelRunner
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.sampling_params import SamplingParams
...
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
-from sglang.srt.managers.router.model_runner import ModelRunner
+from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.model_config import ModelConfig
...
@@ -6,8 +6,8 @@ import torch
 import torch.distributed as dist
 from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.managers.router.infer_batch import ForwardMode
-from sglang.srt.managers.router.model_runner import InputMetadata, ModelRunner
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata, ModelRunner
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.utils import load_image
...