Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a20e7df8
Unverified
Commit
a20e7df8
authored
Oct 13, 2025
by
Yongtong Wu
Committed by
GitHub
Oct 12, 2025
Browse files
Improve dp attention port assignment scheme (#5889)
Co-authored-by:
Cheng Wan
<
cwan@x.ai
>
parent
1bdd0102
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
175 additions
and
44 deletions
+175
-44
python/sglang/srt/managers/data_parallel_controller.py
python/sglang/srt/managers/data_parallel_controller.py
+122
-26
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+10
-3
python/sglang/srt/utils/common.py
python/sglang/srt/utils/common.py
+41
-14
test/srt/test_server_args.py
test/srt/test_server_args.py
+2
-1
No files found.
python/sglang/srt/managers/data_parallel_controller.py
View file @
a20e7df8
...
...
@@ -21,7 +21,7 @@ import threading
import
time
from
collections
import
deque
from
enum
import
Enum
,
auto
from
typing
import
List
from
typing
import
List
,
Optional
import
psutil
import
setproctitle
...
...
@@ -36,7 +36,11 @@ from sglang.srt.managers.io_struct import (
)
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
(
DP_ATTENTION_HANDSHAKE_PORT_DELTA
,
PortArgs
,
ServerArgs
,
)
from
sglang.srt.utils
import
(
bind_port
,
configure_logger
,
...
...
@@ -140,22 +144,12 @@ class DataParallelController:
self
.
workers
:
List
[
zmq
.
Socket
]
=
[
None
]
*
server_args
.
dp_size
if
server_args
.
enable_dp_attention
:
dp_port_args
=
self
.
launch_dp_attention_schedulers
(
server_args
,
port_args
)
self
.
launch_dp_attention_schedulers
(
server_args
,
port_args
)
self
.
control_message_step
=
server_args
.
tp_size
else
:
dp_port_args
=
self
.
launch_dp_schedulers
(
server_args
,
port_args
)
self
.
launch_dp_schedulers
(
server_args
,
port_args
)
self
.
control_message_step
=
1
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
if
server_args
.
node_rank
==
0
:
for
dp_rank
in
range
(
server_args
.
dp_size
):
self
.
workers
[
dp_rank
]
=
get_zmq_socket
(
self
.
context
,
zmq
.
PUSH
,
dp_port_args
[
dp_rank
].
scheduler_input_ipc_name
,
True
,
)
self
.
max_req_input_len
=
None
self
.
init_dispatcher
()
...
...
@@ -188,13 +182,11 @@ class DataParallelController:
threads
=
[]
sockets
=
[]
dp_port_args
=
[]
ready_events
=
[]
for
dp_rank
in
range
(
server_args
.
dp_size
):
tmp_port_args
=
PortArgs
.
init_new
(
server_args
)
tmp_port_args
.
tokenizer_ipc_name
=
port_args
.
tokenizer_ipc_name
tmp_port_args
.
detokenizer_ipc_name
=
port_args
.
detokenizer_ipc_name
dp_port_args
.
append
(
tmp_port_args
)
# This port is checked free in PortArgs.init_new.
# We hold it first so that the next dp worker gets a different port
...
...
@@ -213,6 +205,14 @@ class DataParallelController:
server_args
.
tp_size
*
server_args
.
pp_size
*
server_args
.
gpu_id_step
)
if
server_args
.
node_rank
==
0
:
self
.
workers
[
dp_rank
]
=
get_zmq_socket
(
self
.
context
,
zmq
.
PUSH
,
tmp_port_args
.
scheduler_input_ipc_name
,
True
,
)
# Free all sockets before starting the threads to launch TP workers
for
sock
in
sockets
:
sock
.
close
()
...
...
@@ -223,8 +223,6 @@ class DataParallelController:
for
event
in
ready_events
:
event
.
wait
()
return
dp_port_args
def
launch_tensor_parallel_group_thread
(
self
,
server_args
:
ServerArgs
,
...
...
@@ -241,19 +239,115 @@ class DataParallelController:
while
True
:
time
.
sleep
(
30
*
24
*
3600
)
def launch_dp_attention_schedulers(self, server_args, port_args):
    """Launch the tensor parallel group and build per-rank port arguments.

    Starts one tensor parallel group (base GPU id 0, no dp rank) and then
    creates one ``PortArgs`` per data-parallel rank, in rank order.

    Returns:
        A list with ``server_args.dp_size`` ``PortArgs`` entries.
    """
    self.launch_tensor_parallel_group(server_args, port_args, 0, None)
    return [
        PortArgs.init_new(server_args, rank)
        for rank in range(server_args.dp_size)
    ]
def _broadcast_worker_ports(
    self, server_args: ServerArgs, worker_ports: Optional[List[int]] = None
) -> List[int]:
    """Share the worker ports chosen on node 0 with every other node.

    Node 0 takes the server role through ``_broadcast_ports_as_server`` and
    hands the pre-allocated ports to the remaining ``nnodes - 1`` nodes.
    Every other node takes the client role through
    ``_receive_ports_as_client`` and returns whatever node 0 sent.

    Args:
        server_args: Server arguments containing node configuration.
        worker_ports: Pre-allocated worker ports to broadcast.

    Returns:
        List of worker ports (same on all nodes after broadcast).
    """
    # Pick the endpoint used for the inter-node handshake: the distributed
    # init address when one is configured, otherwise a local port derived
    # from the server port.
    if server_args.dist_init_addr is not None:
        endpoint = f"tcp://{server_args.dist_init_addr}"
    else:
        handshake_port = server_args.port + DP_ATTENTION_HANDSHAKE_PORT_DELTA
        endpoint = f"tcp://127.0.0.1:{handshake_port}"

    if server_args.node_rank != 0:
        # Non-zero ranks fetch the ports from node 0.
        return self._receive_ports_as_client(endpoint, server_args.node_rank)

    # Node 0 serves the ports to all remaining nodes.
    return self._broadcast_ports_as_server(
        endpoint, server_args.nnodes - 1, worker_ports
    )
def _broadcast_ports_as_server(
    self, endpoint: str, expected_clients: int, worker_ports: List[int]
) -> List[int]:
    """Serve ``worker_ports`` to ``expected_clients`` nodes over a REP socket.

    Binds a REP socket to ``endpoint``, answers one handshake per client with
    the port list, and closes the socket before returning.

    Args:
        endpoint: ZMQ endpoint to bind the REP socket to.
        expected_clients: Number of handshakes to answer before finishing.
        worker_ports: Port list sent as the reply to every handshake.

    Returns:
        ``worker_ports`` unchanged.
    """
    logger.debug(f"Broadcasting worker ports to {expected_clients} client nodes")
    logger.debug(f"Worker ports: {worker_ports}")

    server = get_zmq_socket(self.context, zmq.REP, endpoint, True)
    try:
        for served in range(1, expected_clients + 1):
            # Each client announces itself by sending its node rank ...
            peer_rank = server.recv().decode()
            logger.debug(f"Received handshake from node {peer_rank}")

            # ... and receives the port list as the reply.
            server.send_pyobj(worker_ports)
            logger.debug(f"Sent worker ports to {served}/{expected_clients} nodes")

        logger.debug("Worker port broadcast completed")
        return worker_ports
    finally:
        server.close()
def _receive_ports_as_client(self, endpoint: str, node_rank: int) -> List[int]:
    """Receive worker ports from the server node over a REQ socket.

    Connects to ``endpoint``, sends this node's rank as the handshake, and
    waits for the reply carrying the worker port list.

    Args:
        endpoint: ZMQ endpoint that node 0 is serving on.
        node_rank: Rank of this node; sent as the handshake payload.

    Returns:
        The worker port list sent back by node 0.

    Raises:
        RuntimeError: If the send or the receive hits its one-minute timeout.
            The original ``zmq.Again`` is attached as ``__cause__``.
    """
    logger.debug("Connecting to node 0 to receive worker ports")

    req_socket = get_zmq_socket(self.context, zmq.REQ, endpoint, False)
    timeout_ms = 60 * 1000  # 1 minute timeout, applied to both directions
    req_socket.setsockopt(zmq.RCVTIMEO, timeout_ms)
    req_socket.setsockopt(zmq.SNDTIMEO, timeout_ms)

    try:
        # Send handshake with our node rank
        req_socket.send(str(node_rank).encode())

        # Receive worker ports
        # NOTE(review): recv_pyobj unpickles whatever arrives on the socket;
        # this is only safe while the endpoint is reachable by trusted nodes.
        worker_ports = req_socket.recv_pyobj()
        logger.debug(f"Received {len(worker_ports)} worker ports from node 0")
        return worker_ports
    except zmq.Again as err:
        logger.error("Timeout waiting for worker ports from node 0")
        # Chain explicitly so the traceback shows the timeout as the cause.
        raise RuntimeError(
            "Failed to receive worker ports from node 0 within timeout"
        ) from err
    finally:
        req_socket.close()
def launch_dp_attention_schedulers(
    self, server_args: ServerArgs, port_args: PortArgs
):
    """Allocate worker ports on node 0, share them, and launch the TP group.

    On node 0 a PUSH socket is created per data-parallel rank; each socket is
    stored in ``self.workers`` and its port is recorded. The port list is then
    passed through ``_broadcast_worker_ports`` and the result is handed to
    ``launch_tensor_parallel_group``.

    Args:
        server_args: Server arguments containing node and dp configuration.
        port_args: Port arguments forwarded to the tensor parallel group.
    """
    # Pre-allocate worker ports on node 0 to avoid conflicts
    worker_ports = []
    if server_args.node_rank == 0:
        for rank in range(server_args.dp_size):
            # Without an endpoint, get_zmq_socket binds to a random TCP port
            # and returns a (port, socket) pair.
            port, push_socket = get_zmq_socket(self.context, zmq.PUSH)
            worker_ports.append(port)
            self.workers[rank] = push_socket
            logger.debug(f"Assigned port {port} to worker {rank}")

    # Nodes other than 0 have nothing to offer, so they pass None.
    broadcasted_ports = self._broadcast_worker_ports(
        server_args, worker_ports or None
    )
    self.launch_tensor_parallel_group(
        server_args, port_args, 0, None, broadcasted_ports
    )
def
launch_tensor_parallel_group
(
self
,
server_args
:
ServerArgs
,
port_args
:
PortArgs
,
base_gpu_id
:
int
,
dp_rank
:
int
,
dp_rank
:
Optional
[
int
],
worker_ports
:
Optional
[
List
[
int
]]
=
None
,
):
if
not
server_args
.
enable_dp_attention
:
logger
.
info
(
f
"Launch DP
{
dp_rank
}
starting at GPU #
{
base_gpu_id
}
."
)
...
...
@@ -290,7 +384,9 @@ class DataParallelController:
server_args
.
dp_size
,
)
# compute zmq ports for this dp rank
rank_port_args
=
PortArgs
.
init_new
(
server_args
,
dp_rank
)
rank_port_args
=
PortArgs
.
init_new
(
server_args
,
dp_rank
,
worker_ports
)
# Data parallelism reuses the tensor parallelism group,
# so all dp ranks should use the same nccl port.
rank_port_args
.
nccl_port
=
port_args
.
nccl_port
...
...
python/sglang/srt/server_args.py
View file @
a20e7df8
...
...
@@ -13,6 +13,8 @@
# ==============================================================================
"""The arguments of the server."""
from
__future__
import
annotations
import
argparse
import
dataclasses
import
json
...
...
@@ -3362,6 +3364,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
ZMQ_TCP_PORT_DELTA
=
233
DP_ATTENTION_HANDSHAKE_PORT_DELTA
=
5
@
dataclasses
.
dataclass
...
...
@@ -3386,7 +3389,11 @@ class PortArgs:
tokenizer_worker_ipc_name
:
Optional
[
str
]
@
staticmethod
def
init_new
(
server_args
,
dp_rank
:
Optional
[
int
]
=
None
)
->
"PortArgs"
:
def
init_new
(
server_args
:
ServerArgs
,
dp_rank
:
Optional
[
int
]
=
None
,
worker_ports
:
Optional
[
List
[
int
]]
=
None
,
)
->
PortArgs
:
if
server_args
.
nccl_port
is
None
:
nccl_port
=
server_args
.
port
+
random
.
randint
(
100
,
1000
)
while
True
:
...
...
@@ -3433,8 +3440,8 @@ class PortArgs:
# TokenizerManager to DataParallelController
scheduler_input_port
=
port_base
+
4
else
:
scheduler_input_port
=
port_base
+
4
+
1
+
dp_rank
assert
worker_ports
is
not
None
scheduler_input_port
=
worker_ports
[
dp_rank
]
return
PortArgs
(
tokenizer_ipc_name
=
f
"tcp://
{
dist_init_host
}
:
{
port_base
}
"
,
scheduler_input_ipc_name
=
f
"tcp://
{
dist_init_host
}
:
{
scheduler_input_port
}
"
,
...
...
python/sglang/srt/utils/common.py
View file @
a20e7df8
...
...
@@ -1291,8 +1291,46 @@ def pytorch_profile(name, func, *args, data_size=-1):
def
get_zmq_socket
(
context
:
zmq
.
Context
,
socket_type
:
zmq
.
SocketType
,
endpoint
:
str
,
bind
:
bool
)
->
zmq
.
Socket
:
context
:
zmq
.
Context
,
socket_type
:
zmq
.
SocketType
,
endpoint
:
Optional
[
str
]
=
None
,
bind
:
bool
=
True
,
)
->
Union
[
zmq
.
Socket
,
Tuple
[
int
,
zmq
.
Socket
]]:
"""Create and configure a ZeroMQ socket.
Args:
context: ZeroMQ context to create the socket from.
socket_type: Type of ZeroMQ socket to create.
endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
Returns:
If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
If endpoint is provided: The configured ZeroMQ socket.
"""
socket
=
context
.
socket
(
socket_type
)
if
endpoint
is
None
:
# Bind to random TCP port
config_socket
(
socket
,
socket_type
)
port
=
socket
.
bind_to_random_port
(
"tcp://*"
)
return
port
,
socket
else
:
# Handle IPv6 if endpoint contains brackets
if
endpoint
.
find
(
"["
)
!=
-
1
:
socket
.
setsockopt
(
zmq
.
IPV6
,
1
)
config_socket
(
socket
,
socket_type
)
if
bind
:
socket
.
bind
(
endpoint
)
else
:
socket
.
connect
(
endpoint
)
return
socket
def
config_socket
(
socket
,
socket_type
:
zmq
.
SocketType
):
mem
=
psutil
.
virtual_memory
()
total_mem
=
mem
.
total
/
1024
**
3
available_mem
=
mem
.
available
/
1024
**
3
...
...
@@ -1301,10 +1339,6 @@ def get_zmq_socket(
else
:
buf_size
=
-
1
socket
=
context
.
socket
(
socket_type
)
if
endpoint
.
find
(
"["
)
!=
-
1
:
socket
.
setsockopt
(
zmq
.
IPV6
,
1
)
def
set_send_opt
():
socket
.
setsockopt
(
zmq
.
SNDHWM
,
0
)
socket
.
setsockopt
(
zmq
.
SNDBUF
,
buf_size
)
...
...
@@ -1317,19 +1351,12 @@ def get_zmq_socket(
set_send_opt
()
elif
socket_type
==
zmq
.
PULL
:
set_recv_opt
()
elif
socket_type
==
zmq
.
DEALER
:
elif
socket_type
in
[
zmq
.
DEALER
,
zmq
.
REQ
,
zmq
.
REP
]
:
set_send_opt
()
set_recv_opt
()
else
:
raise
ValueError
(
f
"Unsupported socket type:
{
socket_type
}
"
)
if
bind
:
socket
.
bind
(
endpoint
)
else
:
socket
.
connect
(
endpoint
)
return
socket
def
dump_to_file
(
dirpath
,
name
,
value
):
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
...
...
test/srt/test_server_args.py
View file @
a20e7df8
...
...
@@ -75,7 +75,8 @@ class TestPortArgs(unittest.TestCase):
server_args
.
nnodes
=
1
server_args
.
dist_init_addr
=
"192.168.1.1:25000"
port_args
=
PortArgs
.
init_new
(
server_args
,
dp_rank
=
2
)
worker_ports
=
[
25006
,
25007
,
25008
,
25009
]
port_args
=
PortArgs
.
init_new
(
server_args
,
dp_rank
=
2
,
worker_ports
=
worker_ports
)
self
.
assertTrue
(
port_args
.
scheduler_input_ipc_name
.
endswith
(
":25008"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment