Commit f719d9ae authored by Lianmin Zheng

Launch dp ranks in parallel (#2053)


Co-authored-by: Haotian Liu <6631389+haotian-liu@users.noreply.github.com>
parent edad3731
@@ -3,7 +3,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
......
@@ -17,6 +17,7 @@ limitations under the License.
 
 import logging
 import multiprocessing as mp
+import threading
 from enum import Enum, auto
 
 import zmq
@@ -28,6 +29,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    bind_port,
     configure_logger,
     get_zmq_socket,
     kill_parent_process,
@@ -80,35 +82,62 @@ class DataParallelController:
 
         # Start data parallel workers
         base_gpu_id = 0
-        self.workers = []
-        scheduler_pipe_readers = []
+        self.workers = [None] * server_args.dp_size
+
+        threads = []
+        sockets = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
 
             if server_args.enable_dp_attention:
-                # Share workers for DP and TP
-                send_to, reader = self.launch_tensor_parallel_process(
-                    server_args,
-                    tmp_port_args,
-                    base_gpu_id,
-                    dp_rank,
-                )
-                base_gpu_id += 1
-                scheduler_pipe_readers.append(reader)
+                # Data parallelism reuses the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                tmp_port_args.nccl_port = port_args.nccl_port
             else:
-                send_to = self.launch_tensor_parallel_group(
-                    server_args,
-                    tmp_port_args,
-                    base_gpu_id,
-                    dp_rank,
-                )
-                base_gpu_id += server_args.tp_size
-            self.workers.append(send_to)
-
-        for reader in scheduler_pipe_readers:
-            reader.recv()
+                # This port is checked free in PortArgs.init_new.
+                # We hold it first so that the next dp worker gets a different port.
+                sockets.append(bind_port(tmp_port_args.nccl_port))
+
+            # Create a thread for each worker
+            thread = threading.Thread(
+                target=self.launch_worker_func,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+            )
+            threads.append(thread)
+            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+        # Free all sockets before starting the threads to launch TP workers
+        for sock in sockets:
+            sock.close()
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+    def launch_worker_func(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
+        launch_func_ = (
+            self.launch_tensor_parallel_process
+            if server_args.enable_dp_attention
+            else self.launch_tensor_parallel_group
+        )
+        self.workers[dp_rank] = launch_func_(
+            server_args,
+            port_args,
+            base_gpu_id,
+            dp_rank,
+        )
 
     def launch_tensor_parallel_group(
         self,
@@ -164,8 +193,8 @@ class DataParallelController:
         send_to = get_zmq_socket(
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
-        return send_to, reader
+        reader.recv()
+        return send_to
 
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
......
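Taken together, the hunks above change the launch flow from "launch each DP rank, then wait for all pipe readers" to "spawn one thread per DP rank and join them", with self.workers pre-sized so each thread fills only its own slot. Below is a minimal, self-contained sketch of that pattern, assuming a placeholder launch_rank in place of the real launch_worker_func:

```python
import threading


def launch_rank(dp_rank: int, workers: list) -> None:
    # Placeholder for launch_worker_func: start the scheduler / TP group for
    # this rank and record the handle used to send it requests.
    workers[dp_rank] = f"worker-{dp_rank}"


def launch_all(dp_size: int) -> list:
    workers = [None] * dp_size  # each thread writes only its own slot
    threads = [
        threading.Thread(target=launch_rank, args=(dp_rank, workers))
        for dp_rank in range(dp_size)
    ]
    for t in threads:  # start every rank concurrently
        t.start()
    for t in threads:  # block until all ranks are ready
        t.join()
    return workers


print(launch_all(dp_size=4))
```

In the real controller each thread also gets its own PortArgs, and the rank's nccl port is reserved with bind_port before any thread starts (see the utils hunk and the sketch after it).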
@@ -159,7 +159,7 @@ class ServerArgs:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.83
+                self.mem_fraction_static = 0.82
             elif self.tp_size >= 4:
                 self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
@@ -211,7 +211,7 @@ class ServerArgs:
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled."
+                "The CUDA graph is disabled. Data parallel size is adjusted to be the same as tensor parallel size."
             )
 
         if self.enable_overlap_schedule:
......
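The extended warning reflects that --enable-dp-attention now implies dp_size == tp_size, which is also why the test below no longer passes an explicit --dp. A hedged sketch of that kind of adjustment, with field names taken from the diff and the halved chunked-prefill size being an assumption rather than the exact sglang logic:

```python
from dataclasses import dataclass


@dataclass
class Args:
    tp_size: int = 2
    dp_size: int = 1
    chunked_prefill_size: int = 8192
    disable_cuda_graph: bool = False
    enable_overlap_schedule: bool = True
    enable_dp_attention: bool = False


def adjust_for_dp_attention(args: Args) -> Args:
    if args.enable_dp_attention:
        args.dp_size = args.tp_size  # DP ranks reuse the TP group
        args.chunked_prefill_size //= 2  # assumption, not taken from the diff
        args.disable_cuda_graph = True  # "The CUDA graph is disabled."
        args.enable_overlap_schedule = False
    return args


print(adjust_for_dp_attention(Args(enable_dp_attention=True)))
```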
@@ -794,6 +794,15 @@ def add_prometheus_middleware(app):
     app.routes.append(metrics_route)
 
 
+def bind_port(port):
+    """Bind to a specific port, assuming it's available."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
+    sock.bind(("", port))
+    sock.listen(1)
+    return sock
+
+
 def get_amdgpu_memory_capacity():
     try:
         # Run rocm-smi and capture the output
......
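bind_port is what lets the controller hold each rank's nccl port until the next rank has picked a different one. A small usage sketch; find_free_port below is illustrative, not an sglang helper:

```python
import socket


def find_free_port() -> int:
    # Ask the OS for any currently free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def bind_port(port: int) -> socket.socket:
    """Bind to a specific port, assuming it's available."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(("", port))
    sock.listen(1)
    return sock


port = find_free_port()
holder = bind_port(port)  # hold the port so another probe cannot return it
other = find_free_port()  # gets a different port while the holder is bound
print(port, other)
holder.close()  # released just before the real worker binds it
```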
@@ -24,8 +24,6 @@ class TestDPAttention(unittest.TestCase):
                 "--trust-remote-code",
                 "--tp",
                 "2",
-                "--dp",
-                "2",
                 "--enable-dp-attention",
             ],
         )
......
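With --dp removed here, the test relies on dp_size being derived from tp_size when --enable-dp-attention is set. A hedged illustration of an equivalent launch, with the model path left as a placeholder:

```python
import subprocess

# Illustrative only: the flags mirror the test's other_args above; <model-path>
# is a placeholder, and the server must be terminated when done.
proc = subprocess.Popen(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "<model-path>",
        "--trust-remote-code",
        "--tp", "2",
        "--enable-dp-attention",
    ]
)
# ... send requests against the server ...
proc.terminate()
```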