Unverified commit a20e7df8, authored by Yongtong Wu, committed by GitHub

Improve dp attention port assignment scheme (#5889)


Co-authored-by: Cheng Wan <cwan@x.ai>
parent 1bdd0102
......@@ -21,7 +21,7 @@ import threading
import time
from collections import deque
from enum import Enum, auto
from typing import List
from typing import List, Optional
import psutil
import setproctitle
......@@ -36,7 +36,11 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import (
DP_ATTENTION_HANDSHAKE_PORT_DELTA,
PortArgs,
ServerArgs,
)
from sglang.srt.utils import (
bind_port,
configure_logger,
......@@ -140,22 +144,12 @@ class DataParallelController:
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
if server_args.enable_dp_attention:
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
self.launch_dp_attention_schedulers(server_args, port_args)
self.control_message_step = server_args.tp_size
else:
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
self.launch_dp_schedulers(server_args, port_args)
self.control_message_step = 1
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
if server_args.node_rank == 0:
for dp_rank in range(server_args.dp_size):
self.workers[dp_rank] = get_zmq_socket(
self.context,
zmq.PUSH,
dp_port_args[dp_rank].scheduler_input_ipc_name,
True,
)
self.max_req_input_len = None
self.init_dispatcher()
......@@ -188,13 +182,11 @@ class DataParallelController:
threads = []
sockets = []
dp_port_args = []
ready_events = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
dp_port_args.append(tmp_port_args)
# This port was verified to be free in PortArgs.init_new.
# We hold it first so that the next dp worker gets a different port.
......@@ -213,6 +205,14 @@ class DataParallelController:
server_args.tp_size * server_args.pp_size * server_args.gpu_id_step
)
if server_args.node_rank == 0:
self.workers[dp_rank] = get_zmq_socket(
self.context,
zmq.PUSH,
tmp_port_args.scheduler_input_ipc_name,
True,
)
# Free all sockets before starting the threads to launch TP workers
for sock in sockets:
sock.close()
......@@ -223,8 +223,6 @@ class DataParallelController:
for event in ready_events:
event.wait()
return dp_port_args
def launch_tensor_parallel_group_thread(
self,
server_args: ServerArgs,
......@@ -241,19 +239,115 @@ class DataParallelController:
while True:
time.sleep(30 * 24 * 3600)
def launch_dp_attention_schedulers(self, server_args, port_args):
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
dp_port_args = []
for dp_rank in range(server_args.dp_size):
dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
return dp_port_args
def _broadcast_worker_ports(
self, server_args: ServerArgs, worker_ports: Optional[List[int]] = None
) -> List[int]:
"""Broadcast worker ports from node 0 to all other nodes.
Node 0 acts as the server, waiting for all other nodes to connect and
sending them the pre-allocated worker ports. Other nodes act as clients,
connecting to node 0 to receive their copy of the worker ports.
Args:
server_args: Server arguments containing node configuration.
worker_ports: Pre-allocated worker ports to broadcast.
Returns:
List of worker ports (same on all nodes after broadcast).
"""
# Determine the endpoint for inter-node communication
if server_args.dist_init_addr is None:
endpoint = f"tcp://127.0.0.1:{server_args.port + DP_ATTENTION_HANDSHAKE_PORT_DELTA}"
else:
endpoint = f"tcp://{server_args.dist_init_addr}"
if server_args.node_rank == 0:
# Node 0: Broadcast worker ports to all other nodes
return self._broadcast_ports_as_server(
endpoint, server_args.nnodes - 1, worker_ports
)
else:
# Other nodes: Receive worker ports from node 0
return self._receive_ports_as_client(endpoint, server_args.node_rank)
def _broadcast_ports_as_server(
self, endpoint: str, expected_clients: int, worker_ports: List[int]
) -> List[int]:
"""Broadcast worker ports to all client nodes."""
logger.debug(f"Broadcasting worker ports to {expected_clients} client nodes")
logger.debug(f"Worker ports: {worker_ports}")
rep_socket = get_zmq_socket(self.context, zmq.REP, endpoint, True)
try:
connected_clients = 0
while connected_clients < expected_clients:
# Wait for client handshake
client_rank = rep_socket.recv().decode()
logger.debug(f"Received handshake from node {client_rank}")
# Send worker ports to client
rep_socket.send_pyobj(worker_ports)
connected_clients += 1
logger.debug(
f"Sent worker ports to {connected_clients}/{expected_clients} nodes"
)
logger.debug("Worker port broadcast completed")
return worker_ports
finally:
rep_socket.close()
def _receive_ports_as_client(self, endpoint: str, node_rank: int) -> List[int]:
"""Receive worker ports from the server node."""
logger.debug(f"Connecting to node 0 to receive worker ports")
req_socket = get_zmq_socket(self.context, zmq.REQ, endpoint, False)
req_socket.setsockopt(zmq.RCVTIMEO, 60 * 1000) # 1 minute timeout
req_socket.setsockopt(zmq.SNDTIMEO, 60 * 1000)
try:
# Send handshake with our node rank
req_socket.send(str(node_rank).encode())
# Receive worker ports
worker_ports = req_socket.recv_pyobj()
logger.debug(f"Received {len(worker_ports)} worker ports from node 0")
return worker_ports
except zmq.Again:
logger.error("Timeout waiting for worker ports from node 0")
raise RuntimeError(
"Failed to receive worker ports from node 0 within timeout"
)
finally:
req_socket.close()
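Taken together, the two helpers above implement a one-round REQ/REP handshake: node 0 binds a REP socket at the handshake endpoint (dist_init_addr, or port + DP_ATTENTION_HANDSHAKE_PORT_DELTA locally) and answers each client's rank message with the same pre-allocated port list. A minimal single-process sketch of the exchange, with hypothetical port values standing in for the real server_args:

import zmq

DP_ATTENTION_HANDSHAKE_PORT_DELTA = 5  # from sglang.srt.server_args
base_port = 30000                      # hypothetical server_args.port
endpoint = f"tcp://127.0.0.1:{base_port + DP_ATTENTION_HANDSHAKE_PORT_DELTA}"

ctx = zmq.Context()

# Node 0 side: REP socket bound at the handshake endpoint.
rep = ctx.socket(zmq.REP)
rep.bind(endpoint)
worker_ports = [30100, 30101]          # hypothetical pre-allocated scheduler input ports

# Client side (normally a different node): REQ socket connects and
# identifies itself with its node rank, as in _receive_ports_as_client.
req = ctx.socket(zmq.REQ)
req.connect(endpoint)
req.send(b"1")

client_rank = rep.recv().decode()      # node 0 receives the handshake ...
rep.send_pyobj(worker_ports)           # ... and replies with the port list

assert req.recv_pyobj() == worker_ports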
def launch_dp_attention_schedulers(
self, server_args: ServerArgs, port_args: PortArgs
):
# Pre-allocate worker ports on node 0 to avoid conflicts
worker_ports = []
if server_args.node_rank == 0:
for dp_rank in range(server_args.dp_size):
port_and_socket = get_zmq_socket(self.context, zmq.PUSH)
worker_ports.append(port_and_socket[0])
self.workers[dp_rank] = port_and_socket[1]
logger.debug(f"Assigned port {port_and_socket[0]} to worker {dp_rank}")
broadcasted_ports = self._broadcast_worker_ports(
server_args, worker_ports if worker_ports else None
)
self.launch_tensor_parallel_group(
server_args, port_args, 0, None, broadcasted_ports
)
def launch_tensor_parallel_group(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
dp_rank: Optional[int],
worker_ports: Optional[List[int]] = None,
):
if not server_args.enable_dp_attention:
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
......@@ -290,7 +384,9 @@ class DataParallelController:
server_args.dp_size,
)
# compute zmq ports for this dp rank
rank_port_args = PortArgs.init_new(server_args, dp_rank)
rank_port_args = PortArgs.init_new(
server_args, dp_rank, worker_ports
)
# Data parallelism reuses the tensor parallelism group,
# so all dp ranks should use the same nccl port.
rank_port_args.nccl_port = port_args.nccl_port
......
......@@ -13,6 +13,8 @@
# ==============================================================================
"""The arguments of the server."""
from __future__ import annotations
import argparse
import dataclasses
import json
......@@ -3362,6 +3364,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
ZMQ_TCP_PORT_DELTA = 233
DP_ATTENTION_HANDSHAKE_PORT_DELTA = 5
@dataclasses.dataclass
......@@ -3386,7 +3389,11 @@ class PortArgs:
tokenizer_worker_ipc_name: Optional[str]
@staticmethod
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
def init_new(
server_args: ServerArgs,
dp_rank: Optional[int] = None,
worker_ports: Optional[List[int]] = None,
) -> PortArgs:
if server_args.nccl_port is None:
nccl_port = server_args.port + random.randint(100, 1000)
while True:
......@@ -3433,8 +3440,8 @@ class PortArgs:
# TokenizerManager to DataParallelController
scheduler_input_port = port_base + 4
else:
scheduler_input_port = port_base + 4 + 1 + dp_rank
assert worker_ports is not None
scheduler_input_port = worker_ports[dp_rank]
return PortArgs(
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
......
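With DP attention enabled, PortArgs.init_new now indexes into the broadcast list instead of deriving the port arithmetically. A sketch of the new lookup with hypothetical values, mirroring the updated unit test at the end of this diff:

worker_ports = [25006, 25007, 25008, 25009]   # received via _broadcast_worker_ports
dp_rank = 2
dist_init_host = "192.168.1.1"

# New scheme: look the port up; the old scheme computed port_base + 4 + 1 + dp_rank.
scheduler_input_port = worker_ports[dp_rank]
scheduler_input_ipc_name = f"tcp://{dist_init_host}:{scheduler_input_port}"
assert scheduler_input_ipc_name.endswith(":25008")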
......@@ -1291,8 +1291,46 @@ def pytorch_profile(name, func, *args, data_size=-1):
def get_zmq_socket(
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
) -> zmq.Socket:
context: zmq.Context,
socket_type: zmq.SocketType,
endpoint: Optional[str] = None,
bind: bool = True,
) -> Union[zmq.Socket, Tuple[int, zmq.Socket]]:
"""Create and configure a ZeroMQ socket.
Args:
context: ZeroMQ context to create the socket from.
socket_type: Type of ZeroMQ socket to create.
endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
Returns:
If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
If endpoint is provided: The configured ZeroMQ socket.
"""
socket = context.socket(socket_type)
if endpoint is None:
# Bind to random TCP port
config_socket(socket, socket_type)
port = socket.bind_to_random_port("tcp://*")
return port, socket
else:
# Handle IPv6 if endpoint contains brackets
if endpoint.find("[") != -1:
socket.setsockopt(zmq.IPV6, 1)
config_socket(socket, socket_type)
if bind:
socket.bind(endpoint)
else:
socket.connect(endpoint)
return socket
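The helper now has two call shapes; a short usage sketch, assuming only that get_zmq_socket is importable from sglang.srt.utils:

import zmq
from sglang.srt.utils import get_zmq_socket

ctx = zmq.Context()

# No endpoint: bind to a random TCP port; returns (port, socket).
port, push_sock = get_zmq_socket(ctx, zmq.PUSH)

# Endpoint given: the pre-existing behavior, returning just the configured socket.
pull_sock = get_zmq_socket(ctx, zmq.PULL, f"tcp://127.0.0.1:{port}", bind=False)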
def config_socket(socket, socket_type: zmq.SocketType):
mem = psutil.virtual_memory()
total_mem = mem.total / 1024**3
available_mem = mem.available / 1024**3
......@@ -1301,10 +1339,6 @@ def get_zmq_socket(
else:
buf_size = -1
socket = context.socket(socket_type)
if endpoint.find("[") != -1:
socket.setsockopt(zmq.IPV6, 1)
def set_send_opt():
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)
......@@ -1317,19 +1351,12 @@ def get_zmq_socket(
set_send_opt()
elif socket_type == zmq.PULL:
set_recv_opt()
elif socket_type == zmq.DEALER:
elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP]:
set_send_opt()
set_recv_opt()
else:
raise ValueError(f"Unsupported socket type: {socket_type}")
if bind:
socket.bind(endpoint)
else:
socket.connect(endpoint)
return socket
def dump_to_file(dirpath, name, value):
from sglang.srt.distributed import get_tensor_model_parallel_rank
......
......@@ -75,7 +75,8 @@ class TestPortArgs(unittest.TestCase):
server_args.nnodes = 1
server_args.dist_init_addr = "192.168.1.1:25000"
port_args = PortArgs.init_new(server_args, dp_rank=2)
worker_ports = [25006, 25007, 25008, 25009]
port_args = PortArgs.init_new(server_args, dp_rank=2, worker_ports=worker_ports)
self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25008"))
......