Unverified Commit 09593e9b authored by Ying Sheng, committed by GitHub

Multi-node Tensor Parallelism (#550)


Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
parent 53a7ebd8
@@ -20,7 +20,7 @@ python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
@@ -36,7 +36,7 @@ python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-cha
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
...
@@ -24,7 +24,7 @@ if __name__ == "__main__":
         raise ValueError(f"Invalid backend: {args.backend}")

     url = f"{args.host}:{args.port}"
-    a = random.randint(0, 1 << 20)
+    a = 20
     max_new_tokens = 256
     prompt = f"{a, }"
...
@@ -2,7 +2,8 @@
 import argparse

-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
...
@@ -76,8 +76,9 @@ def start_controller_process(
     )

     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-            list(range(server_args.tp_size)),
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
             server_args,
             port_args.model_port_args[0],
             model_overide_args,
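For context on the hunk above: the first argument to ModelTpClient is now a list of local GPU indices repeated once per node, rather than the global ranks 0..tp_size-1. A minimal sketch of the expansion, assuming tp_size=8 and nnodes=2:

```python
# Hedged illustration of the gpu_ids expression above (tp_size=8, nnodes=2 assumed).
nnodes, tp_size = 2, 8
tp_size_local = tp_size // nnodes  # 4 GPUs per node
gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
print(gpu_ids)  # [0, 1, 2, 3, 0, 1, 2, 3] -- each node indexes its own GPUs from 0
```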
...
@@ -246,12 +246,16 @@ class ModelRunner:
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
         monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
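The effect of the new branch is that every tensor-parallel rank dials one shared rendezvous address in multi-node runs, and falls back to loopback otherwise. A minimal standalone sketch, with placeholder addresses:

```python
# Hedged sketch of the rendezvous choice above; the address and port are made up.
def pick_init_method(nccl_init_addr, nccl_port):
    """Mirror the if/else added above: shared TCP address, else loopback."""
    if nccl_init_addr:
        return f"tcp://{nccl_init_addr}"
    return f"tcp://127.0.0.1:{nccl_port}"

assert pick_init_method("10.0.0.1:17000", 28000) == "tcp://10.0.0.1:17000"
assert pick_init_method(None, 28000) == "tcp://127.0.0.1:28000"
```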
@@ -311,7 +315,7 @@ class ModelRunner:
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
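`cell_size` prices the KV cache for one token: kv_heads × head_dim × layers × 2 tensors (K and V) × 2 bytes for fp16. A worked example with Llama-2-7B-like shapes (assumed):

```python
# Hedged sizing example; Llama-2-7B-like shapes assumed, tp_size=1.
num_kv_heads, head_dim, num_layers = 32, 128, 32
cell_size = num_kv_heads * head_dim * num_layers * 2 * 2  # K and V, 2 bytes each (fp16)
print(cell_size)  # 524288 bytes, i.e. 0.5 MiB of KV cache per token
```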
@@ -324,7 +328,7 @@ class ModelRunner:
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )
         self.req_to_token_pool = ReqToTokenPool(
...
@@ -37,7 +37,8 @@ from sglang.srt.utils import (
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_process,
+    start_rpyc_service_process,
+    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -770,12 +771,17 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-                rets = executor.map(
-                    lambda args: start_rpyc_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))

                 # Init model
                 def init_model(i):
@@ -787,7 +793,7 @@ class ModelTpClient:
                         model_overide_args,
                     )

-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))

                 # Wrap functions
                 def async_wrap(func_name):
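The rewrite above separates starting rpyc services from connecting to them: a single-node run still spawns local service processes, while a multi-node run only connects to services that the remote nodes started at boot. A hedged sketch of the two address lists (IPs and ports invented):

```python
# Hedged illustration of the two addressing branches above; values are made up.
model_tp_ports = [28001, 28002, 28003, 28004]

# nnodes == 1: services were just spawned locally, so dial loopback.
addrs = [("localhost", p) for p in model_tp_ports]

# nnodes > 1: remote nodes already run their services and shipped their
# (ip, port) pairs to rank 0, which only connects.
model_tp_ips = ["10.0.0.1", "10.0.0.1", "10.0.0.2", "10.0.0.2"]
addrs = list(zip(model_tp_ips, model_tp_ports))
print(addrs)  # [('10.0.0.1', 28001), ..., ('10.0.0.2', 28004)]
```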
...
@@ -71,7 +71,11 @@ class ModelConfig:
             return 1

         # For DBRX and MPT
-        if self.hf_config.model_type in ["dbrx", "mpt"]:
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
             return getattr(
                 self.hf_config.attn_config,
                 "kv_n_heads",
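The split above is needed because the two model types expose `attn_config` differently: MPT ships it as a plain dict, DBRX as an attribute-style config object. A minimal sketch of the assumed shapes:

```python
# Hedged sketch of why MPT and DBRX need different lookups; shapes are assumed.
mpt_attn_config = {"kv_n_heads": 8}  # MPT: a dict -> membership test + [] access

class DbrxAttnConfig:                # DBRX: an object -> getattr with a default
    kv_n_heads = 8

print(mpt_attn_config["kv_n_heads"])              # 8
print(getattr(DbrxAttnConfig, "kv_n_heads", 16))  # 8; 16 would be the fallback
```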
...
@@ -35,6 +35,7 @@ from sglang.srt.managers.controller.manager_multi import (
 from sglang.srt.managers.controller.manager_single import (
     start_controller_process as start_controller_process_single,
 )
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -50,9 +51,13 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -151,21 +156,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     load_chat_template_for_openai_api(server_args.chat_template)

     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        server_args.tp_size,
+        tp_size_local,
         server_args.dp_size,
     )
     ports = server_args.additional_ports
-    tp = server_args.tp_size
     model_port_args = []
     for i in range(server_args.dp_size):
         model_port_args.append(
             ModelPortArgs(
-                nccl_port=ports[3 + i * (tp + 1)],
-                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
             )
         )
     port_args = PortArgs(
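A quick check of the port indexing above, assuming the first three entries of `ports` are consumed by other components (hence the offset of 3) and a dp_size=2, tp_size_local=2 layout:

```python
# Hedged worked example of the port layout; sizes and port numbers are assumed.
ports = list(range(30000, 30012))  # ports[0:3] assumed used elsewhere
dp_size, tp_size_local = 2, 2
for i in range(dp_size):
    nccl_port = ports[3 + i * (tp_size_local + 1)]
    tp_ports = ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)]
    print(i, nccl_port, tp_ports)
# 0 30003 [30004, 30005]
# 1 30006 [30007, 30008]
```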
@@ -175,6 +182,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )

+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            print("Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
...
@@ -56,6 +56,11 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False

+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -252,6 +257,24 @@ class ServerArgs:
             ],
         )

+        # Multi-node distributed serving args
+        parser.add_argument(
+            "--nccl-init-addr",
+            type=str,
+            help="The nccl init address of multi-node server.",
+        )
+        parser.add_argument(
+            "--nnodes",
+            type=int,
+            default=1,
+            help="Number of nodes",
+        )
+        parser.add_argument(
+            "--node-rank",
+            type=int,
+            help="The node rank.",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
@@ -300,6 +323,7 @@ class ServerArgs:
 @dataclasses.dataclass
 class ModelPortArgs:
     nccl_port: int
+    model_tp_ips: List[str]
     model_tp_ports: List[int]
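Putting the three new flags together, a multi-node deployment would pass the same rendezvous address on every node and a distinct rank per node. A hedged sketch; the entry point, model path, and every flag other than the three added here are assumptions based on sglang's usual CLI:

```python
# Hedged two-node launch sketch. Only --nccl-init-addr, --nnodes, and
# --node-rank come from this commit; the rest is assumed:
#
#   node 0: python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf \
#               --tp-size 8 --nnodes 2 --node-rank 0 --nccl-init-addr 10.0.0.1:17000
#   node 1: the same command with --node-rank 1
#
from sglang.srt.server_args import ServerArgs  # import path shown in this diff

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    tp_size=8,
    nccl_init_addr="10.0.0.1:17000",
    nnodes=2,
    node_rank=0,
)
assert args.tp_size % args.nnodes == 0  # mirrors the check in launch_server
```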
"""Common utilities.""" """Common utilities."""
import base64 import base64
import fcntl
import logging import logging
import multiprocessing import multiprocessing
import os import os
import random import random
import socket import socket
import struct
import time import time
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from io import BytesIO from io import BytesIO
@@ -369,23 +371,7 @@ def load_image(image_file):
     return image, image_size


-def init_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def connect_to_rpyc_service(port, host="localhost"):
-    time.sleep(1)
+def connect_rpyc_service(host, port):
     repeat_count = 0
     while repeat_count < 20:
         try:
@@ -399,22 +385,33 @@ def connect_rpyc_service(host, port):
                 },
             )
             break
-        except ConnectionRefusedError:
+        except ConnectionRefusedError as e:
             time.sleep(1)
         repeat_count += 1
     if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
+        raise RuntimeError(f"Connect rpyc error: {e}")

     return con.root


-def start_rpyc_process(service: rpyc.Service, port: int):
-    # Return the proxy and the process
-    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
     proc.start()
-    proxy = connect_to_rpyc_service(port)
-    assert proc.is_alive()
-    return proxy, proc
+    return proc


 def suppress_other_loggers():
@@ -487,3 +484,66 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
             )
         response = await call_next(request)
         return response
+
+
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack("256s", bytes(ifname[:15], "utf-8")),
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
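SIOCGIFADDR fills a `struct ifreq`; bytes 20:24 of the returned buffer are the IPv4 address inside the embedded `sockaddr_in`. A hedged usage sketch; the interface name is machine-specific:

```python
# Hedged usage; "eth0" is an assumption -- substitute your NIC (see `ip addr`).
try:
    print(get_ip_address("eth0"))  # e.g. "10.0.0.1"
except OSError:
    print(get_ip_address("lo"))    # loopback fallback: "127.0.0.1"
```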
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.send(addrs_tensor, dst=0)
+    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+
+    dist.barrier()
+    dist.destroy_process_group()
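The payload is one flat int tensor: the four octets of the sender's IPv4 address followed by its TP ports, which `receive_addrs` below splits back apart. A round-trip check of that encoding (values invented):

```python
# Hedged round-trip check of the address encoding used above; values invented.
ip, ports = "10.0.0.2", [28001, 28002]
payload = [int(x) for x in ip.split(".")] + ports   # [10, 0, 0, 2, 28001, 28002]
decoded_ip = ".".join(str(x) for x in payload[:4])
decoded_ports = payload[4:]
assert (decoded_ip, decoded_ports) == (ip, ports)
```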
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
+
+    dist.barrier()
+    dist.destroy_process_group()
\ No newline at end of file