"model/vscode:/vscode.git/clone" did not exist on "08fbb60bb24bcc178e6ca630e4bb2ef313f2d274"
Commit f719d9ae authored by Lianmin Zheng

Launch dp ranks in parallel (#2053)


Co-authored-by: Haotian Liu <6631389+haotian-liu@users.noreply.github.com>
parent edad3731
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 import torch
-import torch.nn as nn
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
...
@@ -17,6 +17,7 @@ limitations under the License.
 import logging
 import multiprocessing as mp
+import threading
 from enum import Enum, auto
 import zmq
@@ -28,6 +29,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    bind_port,
     configure_logger,
     get_zmq_socket,
     kill_parent_process,
@@ -80,35 +82,62 @@ class DataParallelController:

         # Start data parallel workers
         base_gpu_id = 0
-        self.workers = []
-        scheduler_pipe_readers = []
+        self.workers = [None] * server_args.dp_size
+
+        threads = []
+        sockets = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

             if server_args.enable_dp_attention:
-                # Share workers for DP and TP
-                send_to, reader = self.launch_tensor_parallel_process(
-                    server_args,
-                    tmp_port_args,
-                    base_gpu_id,
-                    dp_rank,
-                )
-                base_gpu_id += 1
-                scheduler_pipe_readers.append(reader)
+                # Data parallelism reuses the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                tmp_port_args.nccl_port = port_args.nccl_port
             else:
-                send_to = self.launch_tensor_parallel_group(
-                    server_args,
-                    tmp_port_args,
-                    base_gpu_id,
-                    dp_rank,
-                )
-                base_gpu_id += server_args.tp_size
-            self.workers.append(send_to)
+                # This port is checked free in PortArgs.init_new.
+                # We hold it first so that the next dp worker gets a different port.
+                sockets.append(bind_port(tmp_port_args.nccl_port))
+
+            # Create a thread for each worker
+            thread = threading.Thread(
+                target=self.launch_worker_func,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+            )
+            threads.append(thread)
+            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+        # Free all sockets before starting the threads to launch TP workers
+        for sock in sockets:
+            sock.close()
+
+        # Start all threads
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()

-        for reader in scheduler_pipe_readers:
-            reader.recv()
+    def launch_worker_func(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
+        launch_func_ = (
+            self.launch_tensor_parallel_process
+            if server_args.enable_dp_attention
+            else self.launch_tensor_parallel_group
+        )
+        self.workers[dp_rank] = launch_func_(
+            server_args,
+            port_args,
+            base_gpu_id,
+            dp_rank,
+        )

     def launch_tensor_parallel_group(
         self,
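
The gist of the new launch path: each data-parallel rank is launched from its own thread, every thread writes its result into a pre-sized list indexed by rank, and the main thread joins all launcher threads before continuing. Below is a minimal, self-contained sketch of that pattern; `launch_worker`, `launch_all`, and the GPU accounting are illustrative stand-ins, not the sglang API.

```python
import threading

def launch_worker(dp_rank: int, base_gpu_id: int) -> str:
    # Stand-in for the real per-rank launch, which starts a scheduler
    # (or a whole tensor-parallel group) pinned to GPUs starting at base_gpu_id.
    return f"worker-{dp_rank}@gpu{base_gpu_id}"

def launch_all(dp_size: int, tp_size: int) -> list:
    workers = [None] * dp_size              # pre-sized so each thread fills its own slot
    threads = []
    base_gpu_id = 0
    for dp_rank in range(dp_size):
        def run(rank=dp_rank, gpu=base_gpu_id):   # bind loop variables per thread
            workers[rank] = launch_worker(rank, gpu)
        threads.append(threading.Thread(target=run))
        base_gpu_id += tp_size              # next rank uses the next block of GPUs
    for t in threads:                       # start every launcher ...
        t.start()
    for t in threads:                       # ... then wait for all of them
        t.join()
    return workers

if __name__ == "__main__":
    print(launch_all(dp_size=4, tp_size=2))
```

Writing into a pre-sized list avoids any locking: each thread only touches its own index, so the order in which launches finish does not matter.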
@@ -164,8 +193,8 @@ class DataParallelController:
         send_to = get_zmq_socket(
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
-        return send_to, reader
+        reader.recv()
+        return send_to

     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
...
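
This hunk moves the readiness handshake into `launch_tensor_parallel_process` itself: the launcher now blocks on `reader.recv()` and returns only the ZMQ send socket, instead of handing the pipe back to the controller. A minimal sketch of that pipe-based ready handshake follows, assuming the child signals readiness by sending one message over the pipe (the exact payload sent by `run_scheduler_process` may differ).

```python
import multiprocessing as mp

def child_main(writer):
    # ... expensive startup (load model, allocate memory, bind sockets) ...
    writer.send("ready")          # tell the parent that initialization finished

def launch_child() -> mp.Process:
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=child_main, args=(writer,))
    proc.start()
    reader.recv()                 # block until the child reports it is ready
    return proc

if __name__ == "__main__":
    p = launch_child()
    p.join()
```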
@@ -159,7 +159,7 @@ class ServerArgs:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.83
+                self.mem_fraction_static = 0.82
             elif self.tp_size >= 4:
                 self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
@@ -211,7 +211,7 @@ class ServerArgs:
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled."
+                "The CUDA graph is disabled. Data parallel size is adjusted to be the same as tensor parallel size."
             )

         if self.enable_overlap_schedule:
...
@@ -794,6 +794,15 @@ def add_prometheus_middleware(app):
     app.routes.append(metrics_route)


+def bind_port(port):
+    """Bind to a specific port, assuming it's available."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
+    sock.bind(("", port))
+    sock.listen(1)
+    return sock
+
+
 def get_amdgpu_memory_capacity():
     try:
         # Run rocm-smi and capture the output
...
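
The controller uses `bind_port` to hold each rank's NCCL port while the remaining ranks pick theirs, then releases all of the held sockets just before the worker threads start and bind for real. Below is a hedged, standalone illustration of that reserve-then-release idea; `reserve_port` and `pick_free_port` are illustrative helpers, not the sglang utilities, and the real `PortArgs.init_new` may pick ports differently.

```python
import socket

def reserve_port(port: int) -> socket.socket:
    # Hold the port open so no other process (or rank) can grab it.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(("", port))
    sock.listen(1)
    return sock

def pick_free_port() -> int:
    # One way to find a free port: bind port 0 and ask the OS what it chose.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

if __name__ == "__main__":
    held, ports = [], []
    for _ in range(4):
        port = pick_free_port()
        ports.append(port)
        held.append(reserve_port(port))   # keep it reserved while the other ranks choose
    print("distinct ports:", ports)
    for sock in held:                     # release everything before the workers bind
        sock.close()
```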
@@ -24,8 +24,6 @@ class TestDPAttention(unittest.TestCase):
                 "--trust-remote-code",
                 "--tp",
                 "2",
-                "--dp",
-                "2",
                 "--enable-dp-attention",
             ],
         )
...