"tests/python/vscode:/vscode.git/clone" did not exist on "f4989867713acae87e11993c479723251a0fd942"
Unverified Commit a20e7df8 authored by Yongtong Wu, committed by GitHub

Improve dp attention port assignment scheme (#5889)


Co-authored-by: Cheng Wan <cwan@x.ai>
parent 1bdd0102
@@ -21,7 +21,7 @@ import threading
 import time
 from collections import deque
 from enum import Enum, auto
-from typing import List
+from typing import List, Optional

 import psutil
 import setproctitle
@@ -36,7 +36,11 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import (
+    DP_ATTENTION_HANDSHAKE_PORT_DELTA,
+    PortArgs,
+    ServerArgs,
+)
 from sglang.srt.utils import (
     bind_port,
     configure_logger,
@@ -140,22 +144,12 @@ class DataParallelController:
         self.workers: List[zmq.Socket] = [None] * server_args.dp_size

         if server_args.enable_dp_attention:
-            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.launch_dp_attention_schedulers(server_args, port_args)
             self.control_message_step = server_args.tp_size
         else:
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.launch_dp_schedulers(server_args, port_args)
             self.control_message_step = 1

-        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
-        if server_args.node_rank == 0:
-            for dp_rank in range(server_args.dp_size):
-                self.workers[dp_rank] = get_zmq_socket(
-                    self.context,
-                    zmq.PUSH,
-                    dp_port_args[dp_rank].scheduler_input_ipc_name,
-                    True,
-                )
-
         self.max_req_input_len = None
         self.init_dispatcher()
@@ -188,13 +182,11 @@ class DataParallelController:
         threads = []
         sockets = []
-        dp_port_args = []
         ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
-            dp_port_args.append(tmp_port_args)

             # This port is checked free in PortArgs.init_new.
             # We hold it first so that the next dp worker gets a different port
@@ -213,6 +205,14 @@
                 server_args.tp_size * server_args.pp_size * server_args.gpu_id_step
             )

+            if server_args.node_rank == 0:
+                self.workers[dp_rank] = get_zmq_socket(
+                    self.context,
+                    zmq.PUSH,
+                    tmp_port_args.scheduler_input_ipc_name,
+                    True,
+                )
+
         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:
             sock.close()
@@ -223,8 +223,6 @@
         for event in ready_events:
             event.wait()

-        return dp_port_args
-
     def launch_tensor_parallel_group_thread(
         self,
         server_args: ServerArgs,
@@ -241,19 +239,115 @@
         while True:
             time.sleep(30 * 24 * 3600)

-    def launch_dp_attention_schedulers(self, server_args, port_args):
-        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
-        dp_port_args = []
-        for dp_rank in range(server_args.dp_size):
-            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
-        return dp_port_args
+    def _broadcast_worker_ports(
+        self, server_args: ServerArgs, worker_ports: Optional[List[int]] = None
+    ) -> List[int]:
+        """Broadcast worker ports from node 0 to all other nodes.
+
+        Node 0 acts as the server, waiting for all other nodes to connect and
+        sending them the pre-allocated worker ports. Other nodes act as clients,
+        connecting to node 0 to receive their copy of the worker ports.
+
+        Args:
+            server_args: Server arguments containing node configuration.
+            worker_ports: Pre-allocated worker ports to broadcast.
+
+        Returns:
+            List of worker ports (same on all nodes after broadcast).
+        """
+        # Determine the endpoint for inter-node communication
+        if server_args.dist_init_addr is None:
+            endpoint = f"tcp://127.0.0.1:{server_args.port + DP_ATTENTION_HANDSHAKE_PORT_DELTA}"
+        else:
+            endpoint = f"tcp://{server_args.dist_init_addr}"
+
+        if server_args.node_rank == 0:
+            # Node 0: broadcast worker ports to all other nodes
+            return self._broadcast_ports_as_server(
+                endpoint, server_args.nnodes - 1, worker_ports
+            )
+        else:
+            # Other nodes: receive worker ports from node 0
+            return self._receive_ports_as_client(endpoint, server_args.node_rank)
+
+    def _broadcast_ports_as_server(
+        self, endpoint: str, expected_clients: int, worker_ports: List[int]
+    ) -> List[int]:
+        """Broadcast worker ports to all client nodes."""
+        logger.debug(f"Broadcasting worker ports to {expected_clients} client nodes")
+        logger.debug(f"Worker ports: {worker_ports}")
+        rep_socket = get_zmq_socket(self.context, zmq.REP, endpoint, True)
+
+        try:
+            connected_clients = 0
+            while connected_clients < expected_clients:
+                # Wait for client handshake
+                client_rank = rep_socket.recv().decode()
+                logger.debug(f"Received handshake from node {client_rank}")
+
+                # Send worker ports to client
+                rep_socket.send_pyobj(worker_ports)
+                connected_clients += 1
+                logger.debug(
+                    f"Sent worker ports to {connected_clients}/{expected_clients} nodes"
+                )
+
+            logger.debug("Worker port broadcast completed")
+            return worker_ports
+        finally:
+            rep_socket.close()
+
+    def _receive_ports_as_client(self, endpoint: str, node_rank: int) -> List[int]:
+        """Receive worker ports from the server node."""
+        logger.debug("Connecting to node 0 to receive worker ports")
+        req_socket = get_zmq_socket(self.context, zmq.REQ, endpoint, False)
+        req_socket.setsockopt(zmq.RCVTIMEO, 60 * 1000)  # 1 minute timeout
+        req_socket.setsockopt(zmq.SNDTIMEO, 60 * 1000)
+
+        try:
+            # Send handshake with our node rank
+            req_socket.send(str(node_rank).encode())
+
+            # Receive worker ports
+            worker_ports = req_socket.recv_pyobj()
+            logger.debug(f"Received {len(worker_ports)} worker ports from node 0")
+            return worker_ports
+        except zmq.Again:
+            logger.error("Timeout waiting for worker ports from node 0")
+            raise RuntimeError(
+                "Failed to receive worker ports from node 0 within timeout"
+            )
+        finally:
+            req_socket.close()
+
+    def launch_dp_attention_schedulers(
+        self, server_args: ServerArgs, port_args: PortArgs
+    ):
+        # Pre-allocate worker ports on node 0 to avoid conflicts
+        worker_ports = []
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                port_and_socket = get_zmq_socket(self.context, zmq.PUSH)
+                worker_ports.append(port_and_socket[0])
+                self.workers[dp_rank] = port_and_socket[1]
+                logger.debug(f"Assigned port {port_and_socket[0]} to worker {dp_rank}")
+        broadcasted_ports = self._broadcast_worker_ports(
+            server_args, worker_ports if worker_ports else None
+        )
+        self.launch_tensor_parallel_group(
+            server_args, port_args, 0, None, broadcasted_ports
+        )

     def launch_tensor_parallel_group(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
         base_gpu_id: int,
-        dp_rank: int,
+        dp_rank: Optional[int],
+        worker_ports: Optional[List[int]] = None,
     ):
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
@@ -290,7 +384,9 @@ class DataParallelController:
                 server_args.dp_size,
             )
             # compute zmq ports for this dp rank
-            rank_port_args = PortArgs.init_new(server_args, dp_rank)
+            rank_port_args = PortArgs.init_new(
+                server_args, dp_rank, worker_ports
+            )
             # Data parallelism reuses the tensor parallelism group,
             # so all dp ranks should use the same nccl port.
             rank_port_args.nccl_port = port_args.nccl_port
......
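Note: the new `_broadcast_worker_ports` / `_receive_ports_as_client` pair above is a plain ZeroMQ REQ/REP handshake. Below is a minimal standalone sketch of the same pattern, with a hypothetical endpoint and port list; the real code derives the endpoint from dist_init_addr (or port + DP_ATTENTION_HANDSHAKE_PORT_DELTA) and creates sockets through get_zmq_socket.

import zmq

ENDPOINT = "tcp://127.0.0.1:30005"    # hypothetical handshake endpoint
WORKER_PORTS = [30100, 30101, 30102]  # hypothetical pre-allocated ports

def node0_broadcast(expected_clients: int) -> None:
    # Node 0: answer one handshake per client with the shared port list.
    ctx = zmq.Context()
    rep = ctx.socket(zmq.REP)
    rep.bind(ENDPOINT)
    try:
        for _ in range(expected_clients):
            node_rank = rep.recv().decode()  # client announces its rank
            rep.send_pyobj(WORKER_PORTS)     # reply with the port list
    finally:
        rep.close()
        ctx.term()

def other_node_receive(node_rank: int) -> list:
    # Other nodes: one send/recv round trip, with a receive timeout so a
    # missing node 0 fails loudly (zmq.Again) instead of hanging forever.
    ctx = zmq.Context()
    req = ctx.socket(zmq.REQ)
    req.setsockopt(zmq.RCVTIMEO, 60 * 1000)
    req.connect(ENDPOINT)
    try:
        req.send(str(node_rank).encode())
        return req.recv_pyobj()
    finally:
        req.close()
        ctx.term()

REQ/REP enforces strict send-then-receive alternation on both ends, which is why the server side can be a simple recv/send loop that just counts clients.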
@@ -13,6 +13,8 @@
 # ==============================================================================
 """The arguments of the server."""

+from __future__ import annotations
+
 import argparse
 import dataclasses
 import json
@@ -3362,6 +3364,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 ZMQ_TCP_PORT_DELTA = 233
+DP_ATTENTION_HANDSHAKE_PORT_DELTA = 5


 @dataclasses.dataclass
@@ -3386,7 +3389,11 @@ class PortArgs:
     tokenizer_worker_ipc_name: Optional[str]

     @staticmethod
-    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
+    def init_new(
+        server_args: ServerArgs,
+        dp_rank: Optional[int] = None,
+        worker_ports: Optional[List[int]] = None,
+    ) -> PortArgs:
         if server_args.nccl_port is None:
             nccl_port = server_args.port + random.randint(100, 1000)
             while True:
@@ -3433,8 +3440,8 @@ class PortArgs:
             # TokenizerManager to DataParallelController
             scheduler_input_port = port_base + 4
         else:
-            scheduler_input_port = port_base + 4 + 1 + dp_rank
+            assert worker_ports is not None
+            scheduler_input_port = worker_ports[dp_rank]

         return PortArgs(
             tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
             scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
......
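The port-selection change in PortArgs.init_new is easiest to see with concrete numbers that mirror the updated test at the end of this diff. A simplified sketch of the branch above, using a hypothetical helper name and eliding the port_base derivation:

from typing import List, Optional

def pick_scheduler_input_port(
    port_base: int,
    dp_rank: Optional[int],
    worker_ports: Optional[List[int]],
) -> int:
    if dp_rank is None:
        # TokenizerManager -> DataParallelController
        return port_base + 4
    # Old scheme: port_base + 4 + 1 + dp_rank, computed arithmetically on
    # every node, so the number could already be taken by something else.
    # New scheme: index into the list node 0 actually bound and broadcast.
    assert worker_ports is not None
    return worker_ports[dp_rank]

# With worker_ports = [25006, 25007, 25008, 25009] and dp_rank=2, the
# scheduler input endpoint ends in :25008, exactly as the test asserts.
assert pick_scheduler_input_port(25000, 2, [25006, 25007, 25008, 25009]) == 25008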
@@ -1291,8 +1291,46 @@ def pytorch_profile(name, func, *args, data_size=-1):
 def get_zmq_socket(
-    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-) -> zmq.Socket:
+    context: zmq.Context,
+    socket_type: zmq.SocketType,
+    endpoint: Optional[str] = None,
+    bind: bool = True,
+) -> Union[zmq.Socket, Tuple[int, zmq.Socket]]:
+    """Create and configure a ZeroMQ socket.
+
+    Args:
+        context: ZeroMQ context to create the socket from.
+        socket_type: Type of ZeroMQ socket to create.
+        endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
+        bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
+
+    Returns:
+        If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
+        If endpoint is provided: The configured ZeroMQ socket.
+    """
+    socket = context.socket(socket_type)
+
+    if endpoint is None:
+        # Bind to random TCP port
+        config_socket(socket, socket_type)
+        port = socket.bind_to_random_port("tcp://*")
+        return port, socket
+    else:
+        # Handle IPv6 if endpoint contains brackets
+        if endpoint.find("[") != -1:
+            socket.setsockopt(zmq.IPV6, 1)
+        config_socket(socket, socket_type)
+        if bind:
+            socket.bind(endpoint)
+        else:
+            socket.connect(endpoint)
+        return socket
+
+
+def config_socket(socket, socket_type: zmq.SocketType):
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -1301,10 +1339,6 @@ def get_zmq_socket(
     else:
         buf_size = -1

-    socket = context.socket(socket_type)
-    if endpoint.find("[") != -1:
-        socket.setsockopt(zmq.IPV6, 1)
-
     def set_send_opt():
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
@@ -1317,19 +1351,12 @@ def get_zmq_socket(
         set_send_opt()
     elif socket_type == zmq.PULL:
         set_recv_opt()
-    elif socket_type == zmq.DEALER:
+    elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP]:
         set_send_opt()
         set_recv_opt()
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")

-    if bind:
-        socket.bind(endpoint)
-    else:
-        socket.connect(endpoint)
-
-    return socket
-

 def dump_to_file(dirpath, name, value):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
......
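After this change get_zmq_socket has two calling conventions, per the docstring added above. A usage sketch (hypothetical endpoint and port numbers; assumes sglang is importable):

import zmq
from sglang.srt.utils import get_zmq_socket

ctx = zmq.Context()

# Endpoint given: behaves as before and returns the configured socket.
push = get_zmq_socket(ctx, zmq.PUSH, "tcp://127.0.0.1:30100", True)

# No endpoint: binds to an OS-assigned free TCP port and returns
# (port, socket), so the port number can be broadcast to other nodes.
port, sock = get_zmq_socket(ctx, zmq.PUSH)
print(f"scheduler input socket bound on port {port}")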
@@ -75,7 +75,8 @@ class TestPortArgs(unittest.TestCase):
         server_args.nnodes = 1
         server_args.dist_init_addr = "192.168.1.1:25000"

-        port_args = PortArgs.init_new(server_args, dp_rank=2)
+        worker_ports = [25006, 25007, 25008, 25009]
+        port_args = PortArgs.init_new(server_args, dp_rank=2, worker_ports=worker_ports)

         self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25008"))
......
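Why bind first and broadcast second, rather than keep computing port numbers arithmetically: a bound socket holds its port until it is closed, so the OS-assigned numbers cannot be grabbed by another worker or process between allocation and use. A standalone sketch of the pre-allocation loop (assuming pyzmq; dp_size is arbitrary here):

import zmq

ctx = zmq.Context()
dp_size = 4
workers, worker_ports = [], []
for dp_rank in range(dp_size):
    sock = ctx.socket(zmq.PUSH)
    # The OS picks a free port and the socket keeps it reserved,
    # so later allocations cannot collide with it.
    port = sock.bind_to_random_port("tcp://*")
    workers.append(sock)
    worker_ports.append(port)
print(worker_ports)  # e.g. [41233, 41234, 52117, 38942] (OS-chosen)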