Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a20e7df8
"tests/python/vscode:/vscode.git/clone" did not exist on "f4989867713acae87e11993c479723251a0fd942"
Unverified
Commit
a20e7df8
authored
Oct 13, 2025
by
Yongtong Wu
Committed by
GitHub
Oct 12, 2025
Browse files
Improve dp attention port assignment scheme (#5889)
Co-authored-by:
Cheng Wan
<
cwan@x.ai
>
parent
1bdd0102
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
175 additions
and
44 deletions
+175
-44
python/sglang/srt/managers/data_parallel_controller.py
python/sglang/srt/managers/data_parallel_controller.py
+122
-26
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+10
-3
python/sglang/srt/utils/common.py
python/sglang/srt/utils/common.py
+41
-14
test/srt/test_server_args.py
test/srt/test_server_args.py
+2
-1
No files found.
python/sglang/srt/managers/data_parallel_controller.py
View file @
a20e7df8
...
@@ -21,7 +21,7 @@ import threading
...
@@ -21,7 +21,7 @@ import threading
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
List
from
typing
import
List
,
Optional
import
psutil
import
psutil
import
setproctitle
import
setproctitle
...
@@ -36,7 +36,11 @@ from sglang.srt.managers.io_struct import (
...
@@ -36,7 +36,11 @@ from sglang.srt.managers.io_struct import (
)
)
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
(
DP_ATTENTION_HANDSHAKE_PORT_DELTA
,
PortArgs
,
ServerArgs
,
)
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
bind_port
,
bind_port
,
configure_logger
,
configure_logger
,
...
@@ -140,22 +144,12 @@ class DataParallelController:
...
@@ -140,22 +144,12 @@ class DataParallelController:
self
.
workers
:
List
[
zmq
.
Socket
]
=
[
None
]
*
server_args
.
dp_size
self
.
workers
:
List
[
zmq
.
Socket
]
=
[
None
]
*
server_args
.
dp_size
if
server_args
.
enable_dp_attention
:
if
server_args
.
enable_dp_attention
:
dp_port_args
=
self
.
launch_dp_attention_schedulers
(
server_args
,
port_args
)
self
.
launch_dp_attention_schedulers
(
server_args
,
port_args
)
self
.
control_message_step
=
server_args
.
tp_size
self
.
control_message_step
=
server_args
.
tp_size
else
:
else
:
dp_port_args
=
self
.
launch_dp_schedulers
(
server_args
,
port_args
)
self
.
launch_dp_schedulers
(
server_args
,
port_args
)
self
.
control_message_step
=
1
self
.
control_message_step
=
1
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
if
server_args
.
node_rank
==
0
:
for
dp_rank
in
range
(
server_args
.
dp_size
):
self
.
workers
[
dp_rank
]
=
get_zmq_socket
(
self
.
context
,
zmq
.
PUSH
,
dp_port_args
[
dp_rank
].
scheduler_input_ipc_name
,
True
,
)
self
.
max_req_input_len
=
None
self
.
max_req_input_len
=
None
self
.
init_dispatcher
()
self
.
init_dispatcher
()
...
@@ -188,13 +182,11 @@ class DataParallelController:
...
@@ -188,13 +182,11 @@ class DataParallelController:
threads
=
[]
threads
=
[]
sockets
=
[]
sockets
=
[]
dp_port_args
=
[]
ready_events
=
[]
ready_events
=
[]
for
dp_rank
in
range
(
server_args
.
dp_size
):
for
dp_rank
in
range
(
server_args
.
dp_size
):
tmp_port_args
=
PortArgs
.
init_new
(
server_args
)
tmp_port_args
=
PortArgs
.
init_new
(
server_args
)
tmp_port_args
.
tokenizer_ipc_name
=
port_args
.
tokenizer_ipc_name
tmp_port_args
.
tokenizer_ipc_name
=
port_args
.
tokenizer_ipc_name
tmp_port_args
.
detokenizer_ipc_name
=
port_args
.
detokenizer_ipc_name
tmp_port_args
.
detokenizer_ipc_name
=
port_args
.
detokenizer_ipc_name
dp_port_args
.
append
(
tmp_port_args
)
# This port is checked free in PortArgs.init_new.
# This port is checked free in PortArgs.init_new.
# We hold it first so that the next dp worker gets a different port
# We hold it first so that the next dp worker gets a different port
...
@@ -213,6 +205,14 @@ class DataParallelController:
...
@@ -213,6 +205,14 @@ class DataParallelController:
server_args
.
tp_size
*
server_args
.
pp_size
*
server_args
.
gpu_id_step
server_args
.
tp_size
*
server_args
.
pp_size
*
server_args
.
gpu_id_step
)
)
if
server_args
.
node_rank
==
0
:
self
.
workers
[
dp_rank
]
=
get_zmq_socket
(
self
.
context
,
zmq
.
PUSH
,
tmp_port_args
.
scheduler_input_ipc_name
,
True
,
)
# Free all sockets before starting the threads to launch TP workers
# Free all sockets before starting the threads to launch TP workers
for
sock
in
sockets
:
for
sock
in
sockets
:
sock
.
close
()
sock
.
close
()
...
@@ -223,8 +223,6 @@ class DataParallelController:
...
@@ -223,8 +223,6 @@ class DataParallelController:
for
event
in
ready_events
:
for
event
in
ready_events
:
event
.
wait
()
event
.
wait
()
return
dp_port_args
def
launch_tensor_parallel_group_thread
(
def
launch_tensor_parallel_group_thread
(
self
,
self
,
server_args
:
ServerArgs
,
server_args
:
ServerArgs
,
...
@@ -241,19 +239,115 @@ class DataParallelController:
...
@@ -241,19 +239,115 @@ class DataParallelController:
while
True
:
while
True
:
time
.
sleep
(
30
*
24
*
3600
)
time
.
sleep
(
30
*
24
*
3600
)
def
launch_dp_attention_schedulers
(
self
,
server_args
,
port_args
):
def
_broadcast_worker_ports
(
self
.
launch_tensor_parallel_group
(
server_args
,
port_args
,
0
,
None
)
self
,
server_args
:
ServerArgs
,
worker_ports
:
Optional
[
List
[
int
]]
=
None
dp_port_args
=
[]
)
->
List
[
int
]:
for
dp_rank
in
range
(
server_args
.
dp_size
):
"""Broadcast worker ports from node 0 to all other nodes.
dp_port_args
.
append
(
PortArgs
.
init_new
(
server_args
,
dp_rank
))
return
dp_port_args
Node 0 acts as the server, waiting for all other nodes to connect and
sending them the pre-allocated worker ports. Other nodes act as clients,
connecting to node 0 to receive their copy of the worker ports.
Args:
server_args: Server arguments containing node configuration.
worker_ports: Pre-allocated worker ports to broadcast.
Returns:
List of worker ports (same on all nodes after broadcast).
"""
# Determine the endpoint for inter-node communication
if
server_args
.
dist_init_addr
is
None
:
endpoint
=
f
"tcp://127.0.0.1:
{
server_args
.
port
+
DP_ATTENTION_HANDSHAKE_PORT_DELTA
}
"
else
:
endpoint
=
f
"tcp://
{
server_args
.
dist_init_addr
}
"
if
server_args
.
node_rank
==
0
:
# Node 0: Broadcast worker ports to all other nodes
return
self
.
_broadcast_ports_as_server
(
endpoint
,
server_args
.
nnodes
-
1
,
worker_ports
)
else
:
# Other nodes: Receive worker ports from node 0
return
self
.
_receive_ports_as_client
(
endpoint
,
server_args
.
node_rank
)
def
_broadcast_ports_as_server
(
self
,
endpoint
:
str
,
expected_clients
:
int
,
worker_ports
:
List
[
int
]
)
->
List
[
int
]:
"""Broadcast worker ports to all client nodes."""
logger
.
debug
(
f
"Broadcasting worker ports to
{
expected_clients
}
client nodes"
)
logger
.
debug
(
f
"Worker ports:
{
worker_ports
}
"
)
rep_socket
=
get_zmq_socket
(
self
.
context
,
zmq
.
REP
,
endpoint
,
True
)
try
:
connected_clients
=
0
while
connected_clients
<
expected_clients
:
# Wait for client handshake
client_rank
=
rep_socket
.
recv
().
decode
()
logger
.
debug
(
f
"Received handshake from node
{
client_rank
}
"
)
# Send worker ports to client
rep_socket
.
send_pyobj
(
worker_ports
)
connected_clients
+=
1
logger
.
debug
(
f
"Sent worker ports to
{
connected_clients
}
/
{
expected_clients
}
nodes"
)
logger
.
debug
(
"Worker port broadcast completed"
)
return
worker_ports
finally
:
rep_socket
.
close
()
def
_receive_ports_as_client
(
self
,
endpoint
:
str
,
node_rank
:
int
)
->
List
[
int
]:
"""Receive worker ports from the server node."""
logger
.
debug
(
f
"Connecting to node 0 to receive worker ports"
)
req_socket
=
get_zmq_socket
(
self
.
context
,
zmq
.
REQ
,
endpoint
,
False
)
req_socket
.
setsockopt
(
zmq
.
RCVTIMEO
,
60
*
1000
)
# 1 minute timeout
req_socket
.
setsockopt
(
zmq
.
SNDTIMEO
,
60
*
1000
)
try
:
# Send handshake with our node rank
req_socket
.
send
(
str
(
node_rank
).
encode
())
# Receive worker ports
worker_ports
=
req_socket
.
recv_pyobj
()
logger
.
debug
(
f
"Received
{
len
(
worker_ports
)
}
worker ports from node 0"
)
return
worker_ports
except
zmq
.
Again
:
logger
.
error
(
"Timeout waiting for worker ports from node 0"
)
raise
RuntimeError
(
"Failed to receive worker ports from node 0 within timeout"
)
finally
:
req_socket
.
close
()
def
launch_dp_attention_schedulers
(
self
,
server_args
:
ServerArgs
,
port_args
:
PortArgs
):
# Pre-allocate worker ports on node 0 to avoid conflicts
worker_ports
=
[]
if
server_args
.
node_rank
==
0
:
for
dp_rank
in
range
(
server_args
.
dp_size
):
port_and_socket
=
get_zmq_socket
(
self
.
context
,
zmq
.
PUSH
)
worker_ports
.
append
(
port_and_socket
[
0
])
self
.
workers
[
dp_rank
]
=
port_and_socket
[
1
]
logger
.
debug
(
f
"Assigned port
{
port_and_socket
[
0
]
}
to worker
{
dp_rank
}
"
)
broadcasted_ports
=
self
.
_broadcast_worker_ports
(
server_args
,
worker_ports
if
worker_ports
else
None
)
self
.
launch_tensor_parallel_group
(
server_args
,
port_args
,
0
,
None
,
broadcasted_ports
)
def
launch_tensor_parallel_group
(
def
launch_tensor_parallel_group
(
self
,
self
,
server_args
:
ServerArgs
,
server_args
:
ServerArgs
,
port_args
:
PortArgs
,
port_args
:
PortArgs
,
base_gpu_id
:
int
,
base_gpu_id
:
int
,
dp_rank
:
int
,
dp_rank
:
Optional
[
int
],
worker_ports
:
Optional
[
List
[
int
]]
=
None
,
):
):
if
not
server_args
.
enable_dp_attention
:
if
not
server_args
.
enable_dp_attention
:
logger
.
info
(
f
"Launch DP
{
dp_rank
}
starting at GPU #
{
base_gpu_id
}
."
)
logger
.
info
(
f
"Launch DP
{
dp_rank
}
starting at GPU #
{
base_gpu_id
}
."
)
...
@@ -290,7 +384,9 @@ class DataParallelController:
...
@@ -290,7 +384,9 @@ class DataParallelController:
server_args
.
dp_size
,
server_args
.
dp_size
,
)
)
# compute zmq ports for this dp rank
# compute zmq ports for this dp rank
rank_port_args
=
PortArgs
.
init_new
(
server_args
,
dp_rank
)
rank_port_args
=
PortArgs
.
init_new
(
server_args
,
dp_rank
,
worker_ports
)
# Data parallelism reuses the tensor parallelism group,
# Data parallelism reuses the tensor parallelism group,
# so all dp ranks should use the same nccl port.
# so all dp ranks should use the same nccl port.
rank_port_args
.
nccl_port
=
port_args
.
nccl_port
rank_port_args
.
nccl_port
=
port_args
.
nccl_port
...
...
python/sglang/srt/server_args.py
View file @
a20e7df8
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
# ==============================================================================
# ==============================================================================
"""The arguments of the server."""
"""The arguments of the server."""
from
__future__
import
annotations
import
argparse
import
argparse
import
dataclasses
import
dataclasses
import
json
import
json
...
@@ -3362,6 +3364,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
...
@@ -3362,6 +3364,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
ZMQ_TCP_PORT_DELTA
=
233
ZMQ_TCP_PORT_DELTA
=
233
DP_ATTENTION_HANDSHAKE_PORT_DELTA
=
5
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -3386,7 +3389,11 @@ class PortArgs:
...
@@ -3386,7 +3389,11 @@ class PortArgs:
tokenizer_worker_ipc_name
:
Optional
[
str
]
tokenizer_worker_ipc_name
:
Optional
[
str
]
@
staticmethod
@
staticmethod
def
init_new
(
server_args
,
dp_rank
:
Optional
[
int
]
=
None
)
->
"PortArgs"
:
def
init_new
(
server_args
:
ServerArgs
,
dp_rank
:
Optional
[
int
]
=
None
,
worker_ports
:
Optional
[
List
[
int
]]
=
None
,
)
->
PortArgs
:
if
server_args
.
nccl_port
is
None
:
if
server_args
.
nccl_port
is
None
:
nccl_port
=
server_args
.
port
+
random
.
randint
(
100
,
1000
)
nccl_port
=
server_args
.
port
+
random
.
randint
(
100
,
1000
)
while
True
:
while
True
:
...
@@ -3433,8 +3440,8 @@ class PortArgs:
...
@@ -3433,8 +3440,8 @@ class PortArgs:
# TokenizerManager to DataParallelController
# TokenizerManager to DataParallelController
scheduler_input_port
=
port_base
+
4
scheduler_input_port
=
port_base
+
4
else
:
else
:
scheduler_input_port
=
port_base
+
4
+
1
+
dp_rank
assert
worker_ports
is
not
None
scheduler_input_port
=
worker_ports
[
dp_rank
]
return
PortArgs
(
return
PortArgs
(
tokenizer_ipc_name
=
f
"tcp://
{
dist_init_host
}
:
{
port_base
}
"
,
tokenizer_ipc_name
=
f
"tcp://
{
dist_init_host
}
:
{
port_base
}
"
,
scheduler_input_ipc_name
=
f
"tcp://
{
dist_init_host
}
:
{
scheduler_input_port
}
"
,
scheduler_input_ipc_name
=
f
"tcp://
{
dist_init_host
}
:
{
scheduler_input_port
}
"
,
...
...
python/sglang/srt/utils/common.py
View file @
a20e7df8
...
@@ -1291,8 +1291,46 @@ def pytorch_profile(name, func, *args, data_size=-1):
...
@@ -1291,8 +1291,46 @@ def pytorch_profile(name, func, *args, data_size=-1):
def
get_zmq_socket
(
def
get_zmq_socket
(
context
:
zmq
.
Context
,
socket_type
:
zmq
.
SocketType
,
endpoint
:
str
,
bind
:
bool
context
:
zmq
.
Context
,
)
->
zmq
.
Socket
:
socket_type
:
zmq
.
SocketType
,
endpoint
:
Optional
[
str
]
=
None
,
bind
:
bool
=
True
,
)
->
Union
[
zmq
.
Socket
,
Tuple
[
int
,
zmq
.
Socket
]]:
"""Create and configure a ZeroMQ socket.
Args:
context: ZeroMQ context to create the socket from.
socket_type: Type of ZeroMQ socket to create.
endpoint: Optional endpoint to bind/connect to. If None, binds to a random TCP port.
bind: Whether to bind (True) or connect (False) to the endpoint. Ignored if endpoint is None.
Returns:
If endpoint is None: Tuple of (port, socket) where port is the randomly assigned TCP port.
If endpoint is provided: The configured ZeroMQ socket.
"""
socket
=
context
.
socket
(
socket_type
)
if
endpoint
is
None
:
# Bind to random TCP port
config_socket
(
socket
,
socket_type
)
port
=
socket
.
bind_to_random_port
(
"tcp://*"
)
return
port
,
socket
else
:
# Handle IPv6 if endpoint contains brackets
if
endpoint
.
find
(
"["
)
!=
-
1
:
socket
.
setsockopt
(
zmq
.
IPV6
,
1
)
config_socket
(
socket
,
socket_type
)
if
bind
:
socket
.
bind
(
endpoint
)
else
:
socket
.
connect
(
endpoint
)
return
socket
def
config_socket
(
socket
,
socket_type
:
zmq
.
SocketType
):
mem
=
psutil
.
virtual_memory
()
mem
=
psutil
.
virtual_memory
()
total_mem
=
mem
.
total
/
1024
**
3
total_mem
=
mem
.
total
/
1024
**
3
available_mem
=
mem
.
available
/
1024
**
3
available_mem
=
mem
.
available
/
1024
**
3
...
@@ -1301,10 +1339,6 @@ def get_zmq_socket(
...
@@ -1301,10 +1339,6 @@ def get_zmq_socket(
else
:
else
:
buf_size
=
-
1
buf_size
=
-
1
socket
=
context
.
socket
(
socket_type
)
if
endpoint
.
find
(
"["
)
!=
-
1
:
socket
.
setsockopt
(
zmq
.
IPV6
,
1
)
def
set_send_opt
():
def
set_send_opt
():
socket
.
setsockopt
(
zmq
.
SNDHWM
,
0
)
socket
.
setsockopt
(
zmq
.
SNDHWM
,
0
)
socket
.
setsockopt
(
zmq
.
SNDBUF
,
buf_size
)
socket
.
setsockopt
(
zmq
.
SNDBUF
,
buf_size
)
...
@@ -1317,19 +1351,12 @@ def get_zmq_socket(
...
@@ -1317,19 +1351,12 @@ def get_zmq_socket(
set_send_opt
()
set_send_opt
()
elif
socket_type
==
zmq
.
PULL
:
elif
socket_type
==
zmq
.
PULL
:
set_recv_opt
()
set_recv_opt
()
elif
socket_type
==
zmq
.
DEALER
:
elif
socket_type
in
[
zmq
.
DEALER
,
zmq
.
REQ
,
zmq
.
REP
]
:
set_send_opt
()
set_send_opt
()
set_recv_opt
()
set_recv_opt
()
else
:
else
:
raise
ValueError
(
f
"Unsupported socket type:
{
socket_type
}
"
)
raise
ValueError
(
f
"Unsupported socket type:
{
socket_type
}
"
)
if
bind
:
socket
.
bind
(
endpoint
)
else
:
socket
.
connect
(
endpoint
)
return
socket
def
dump_to_file
(
dirpath
,
name
,
value
):
def
dump_to_file
(
dirpath
,
name
,
value
):
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
...
...
test/srt/test_server_args.py
View file @
a20e7df8
...
@@ -75,7 +75,8 @@ class TestPortArgs(unittest.TestCase):
...
@@ -75,7 +75,8 @@ class TestPortArgs(unittest.TestCase):
server_args
.
nnodes
=
1
server_args
.
nnodes
=
1
server_args
.
dist_init_addr
=
"192.168.1.1:25000"
server_args
.
dist_init_addr
=
"192.168.1.1:25000"
port_args
=
PortArgs
.
init_new
(
server_args
,
dp_rank
=
2
)
worker_ports
=
[
25006
,
25007
,
25008
,
25009
]
port_args
=
PortArgs
.
init_new
(
server_args
,
dp_rank
=
2
,
worker_ports
=
worker_ports
)
self
.
assertTrue
(
port_args
.
scheduler_input_ipc_name
.
endswith
(
":25008"
))
self
.
assertTrue
(
port_args
.
scheduler_input_ipc_name
.
endswith
(
":25008"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment