sglang · Commits · 83d55ac5

Commit 83d55ac5 (unverified)
Authored Sep 09, 2025 by Liangsheng Yin; committed by GitHub on Sep 09, 2025

[1/N]DP refactor: Improve dp rank scheduling in PD disaggregation mode. (#10169)

Parent: 2fe17735

Showing 8 changed files with 61 additions and 36 deletions (+61 / -36):
python/sglang/srt/disaggregation/common/conn.py           +8  -6
python/sglang/srt/disaggregation/decode.py                +1  -1
python/sglang/srt/disaggregation/fake/conn.py             +1  -1
python/sglang/srt/disaggregation/mooncake/conn.py         +8  -6
python/sglang/srt/disaggregation/nixl/conn.py             +2  -2
python/sglang/srt/managers/data_parallel_controller.py    +21 -14
python/sglang/srt/managers/multi_tokenizer_mixin.py       +1  -3
python/sglang/srt/server_args.py                          +19 -3
python/sglang/srt/disaggregation/common/conn.py

@@ -128,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
         self.kv_mgr = mgr
-        self.data_parallel_rank = data_parallel_rank
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
@@ -201,11 +200,14 @@ class CommonKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
-        if self.data_parallel_rank is not None:
-            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
-            self.target_dp_group = self.data_parallel_rank
+        if prefill_dp_rank is not None:
+            logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
+            self.prefill_dp_rank = prefill_dp_rank
         else:
-            self.target_dp_group = bootstrap_room % self.prefill_dp_size
+            self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
+        # FIXME: alias here: target_dp_group -> prefill_dp_rank
+        self.target_dp_group = self.prefill_dp_rank

         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
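With this change, CommonKVReceiver no longer stores the requested rank as data_parallel_rank; it resolves a prefill_dp_rank up front: the explicitly requested rank when one is given, otherwise bootstrap_room % prefill_dp_size, with target_dp_group kept only as an alias of that value. A minimal standalone sketch of the selection rule (resolve_prefill_dp_rank is an illustrative helper, not part of the patch):

```python
from typing import Optional


def resolve_prefill_dp_rank(
    bootstrap_room: int,
    prefill_dp_size: int,
    prefill_dp_rank: Optional[int] = None,
) -> int:
    """Mirror the receiver logic above: prefer an explicit rank, else
    derive one deterministically from the bootstrap room."""
    if prefill_dp_rank is not None:
        # Caller (e.g. an external router) pinned the request to a rank.
        return prefill_dp_rank
    # Fallback: spread rooms across the prefill DP group.
    return bootstrap_room % prefill_dp_size


if __name__ == "__main__":
    # An explicit rank wins over the hash-based fallback.
    assert resolve_prefill_dp_rank(12345, prefill_dp_size=4, prefill_dp_rank=2) == 2
    # Without an explicit rank, the room is hashed across the DP ranks.
    assert resolve_prefill_dp_rank(12345, prefill_dp_size=4) == 12345 % 4
    print("ok")
```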
python/sglang/srt/disaggregation/decode.py

@@ -250,7 +250,7 @@ class DecodePreallocQueue:
                 mgr=self.kv_manager,
                 bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
                 bootstrap_room=req.bootstrap_room,
-                data_parallel_rank=req.data_parallel_rank,
+                prefill_dp_rank=req.data_parallel_rank,
             )
             self.queue.append(
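For call sites such as DecodePreallocQueue, the change is only a keyword rename: the request's data_parallel_rank is still forwarded, but under the prefill_dp_rank name. A small sketch of that call shape, using hypothetical stand-ins (FakeReq and StubKVReceiver are illustrative, not sglang classes):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeReq:
    # Hypothetical stand-in for the fields the prealloc queue reads off a request.
    bootstrap_host: str
    bootstrap_port: int
    bootstrap_room: int
    data_parallel_rank: Optional[int] = None


class StubKVReceiver:
    # Same keyword surface as the patched receivers, minus the transfer logic.
    def __init__(self, mgr, bootstrap_addr, bootstrap_room=None, prefill_dp_rank=None):
        self.bootstrap_addr = bootstrap_addr
        self.bootstrap_room = bootstrap_room
        self.prefill_dp_rank = prefill_dp_rank


req = FakeReq("10.0.0.1", 8998, bootstrap_room=42, data_parallel_rank=1)
receiver = StubKVReceiver(
    mgr=None,
    bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
    bootstrap_room=req.bootstrap_room,
    prefill_dp_rank=req.data_parallel_rank,  # was data_parallel_rank=... before this commit
)
assert receiver.prefill_dp_rank == 1
```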
python/sglang/srt/disaggregation/fake/conn.py

@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.has_init = False
python/sglang/srt/disaggregation/mooncake/conn.py

@@ -1212,7 +1212,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         mgr: MooncakeKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
@@ -1221,7 +1221,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
         self.conclude_state = None
         self.init_time = None
-        self.data_parallel_rank = data_parallel_rank
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             (
@@ -1320,11 +1319,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
             ) * (self.prefill_pp_size // self.kv_mgr.pp_size)

-        if self.data_parallel_rank is not None:
-            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
-            self.target_dp_group = self.data_parallel_rank
+        if prefill_dp_rank is not None:
+            logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
+            self.prefill_dp_rank = prefill_dp_rank
         else:
-            self.target_dp_group = bootstrap_room % self.prefill_dp_size
+            self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
+        # FIXME: alias here: target_dp_group -> prefill_dp_rank
+        self.target_dp_group = self.prefill_dp_rank

         self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
             self.required_prefill_response_num
python/sglang/srt/disaggregation/nixl/conn.py

@@ -454,11 +454,11 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.started_transfer = False
         self.conclude_state = None
-        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)

     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
python/sglang/srt/managers/data_parallel_controller.py

@@ -106,7 +106,7 @@ class DataParallelController:
         # Launch data parallel workers
         self.scheduler_procs = []
-        self.workers = [None] * server_args.dp_size
+        self.workers: List[zmq.Socket] = [None] * server_args.dp_size

         if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
@@ -272,27 +272,34 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

+    def maybe_external_dp_rank_routing(self, req: Req):
+        if req.data_parallel_rank is not None:
+            logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+            self.workers[req.data_parallel_rank].send_pyobj(req)
+            return True
+        return False
+
     def round_robin_scheduler(self, req: Req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+
         if self.server_args.disaggregation_mode == "null":
-            if req.data_parallel_rank is not None:
-                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
-                self.workers[req.data_parallel_rank].send_pyobj(req)
-            else:
-                self.workers[self.round_robin_counter].send_pyobj(req)
-                self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                    self.workers
-                )
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
         else:
-            if req.data_parallel_rank is not None:
-                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
-                self.workers[req.data_parallel_rank].send_pyobj(req)
-            else:
-                self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

     def shortest_queue_scheduler(self, input_requests):
+        if self.maybe_external_dp_rank_routing(req):
+            return
         raise NotImplementedError()

     def minimum_tokens_scheduler(self, req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
         # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
         # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
         def get_next_global_balance_id() -> int:
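The controller now factors the "request already carries a DP rank" case into maybe_external_dp_rank_routing, which each scheduling policy tries first; only when it returns False does the round-robin path (or, in PD disaggregation, the bootstrap-room hash) run. A self-contained sketch of that dispatch order, with a plain Python stub standing in for the ZMQ worker sockets (FakeWorker, Req, and MiniController are illustrative, not sglang classes):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeWorker:
    # Stand-in for a zmq.Socket; it just records what it was sent.
    inbox: List[object] = field(default_factory=list)

    def send_pyobj(self, obj) -> None:
        self.inbox.append(obj)


@dataclass
class Req:
    rid: str
    bootstrap_room: int = 0
    data_parallel_rank: Optional[int] = None


class MiniController:
    def __init__(self, dp_size: int, disaggregation_mode: str = "null"):
        self.workers = [FakeWorker() for _ in range(dp_size)]
        self.round_robin_counter = 0
        self.disaggregation_mode = disaggregation_mode

    def maybe_external_dp_rank_routing(self, req: Req) -> bool:
        # Honor an externally assigned DP rank before any balancing policy.
        if req.data_parallel_rank is not None:
            self.workers[req.data_parallel_rank].send_pyobj(req)
            return True
        return False

    def round_robin_scheduler(self, req: Req) -> None:
        if self.maybe_external_dp_rank_routing(req):
            return
        if self.disaggregation_mode == "null":
            self.workers[self.round_robin_counter].send_pyobj(req)
            self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
        else:
            # PD disaggregation: keep the mapping deterministic via the bootstrap room.
            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)


ctrl = MiniController(dp_size=2)
ctrl.round_robin_scheduler(Req("a"))                        # round-robin -> worker 0
ctrl.round_robin_scheduler(Req("b", data_parallel_rank=1))  # pinned -> worker 1
ctrl.round_robin_scheduler(Req("c"))                        # round-robin -> worker 1
assert [len(w.inbox) for w in ctrl.workers] == [1, 2]
```

Pinned requests bypass the counter entirely, so unpinned traffic still round-robins evenly across the workers.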
python/sglang/srt/managers/multi_tokenizer_mixin.py

@@ -450,9 +450,7 @@ class MultiTokenizerManager(TokenizerManager):
     server_args: ServerArgs,
     port_args: PortArgs,
 ):
-    setproctitle.setproctitle(
-        f"sglang::tokenizer_worker:{os.getpid()}"
-    )
+    setproctitle.setproctitle(f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}")
     # prevent init prefill bootstrapserver again
     disaggregation_mode = server_args.disaggregation_mode
     server_args.disaggregation_mode = "null"
python/sglang/srt/server_args.py

@@ -44,6 +44,7 @@ from sglang.srt.utils import (
     is_valid_ipv6_address,
     nullable_str,
 )
+from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)
@@ -223,6 +224,8 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    # FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
+    prefill_round_robin_balance: bool = False

     # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
@@ -623,12 +626,12 @@ class ServerArgs:
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"

+        if self.dp_size == 1:
+            self.enable_dp_attention = False
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            assert (
-                self.dp_size > 1
-            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
             assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
             logger.warning(
@@ -807,6 +810,13 @@ class ServerArgs:
             self.disable_radix_cache = True
             logger.warning("KV cache is forced as chunk cache for decode server")

+            if self.dp_size > 1 and not is_in_ci():
+                assert self.prefill_round_robin_balance, (
+                    "Prefill round robin balance is required when dp size > 1. "
+                    "Please make sure that the prefill instance is launched with `--load-balance-method round_robin`"
+                    " and `--prefill-round-robin-balance` is set for decode server."
+                )
+
         elif self.disaggregation_mode == "prefill":
             if self.disaggregation_decode_tp is None:
                 self.disaggregation_decode_tp = self.tp_size
@@ -1384,6 +1394,12 @@ class ServerArgs:
                 "minimum_tokens",
             ],
         )
+        parser.add_argument(
+            "--prefill-round-robin-balance",
+            default=ServerArgs.prefill_round_robin_balance,
+            action="store_true",
+            help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.",
+        )

         # Multi-node distributed serving
         parser.add_argument(
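The new --prefill-round-robin-balance flag is a plain store_true argparse switch; a decode server with dp_size > 1 asserts that it is set (outside CI), since the decode side can only predict which prefill DP rank owns a request if the prefill instance balances requests round-robin. A minimal sketch of the same flag wiring on a bare ArgumentParser (standalone; it does not import sglang):

```python
import argparse

PREFILL_ROUND_ROBIN_BALANCE_DEFAULT = False  # mirrors ServerArgs.prefill_round_robin_balance

parser = argparse.ArgumentParser()
parser.add_argument(
    "--prefill-round-robin-balance",
    default=PREFILL_ROUND_ROBIN_BALANCE_DEFAULT,
    action="store_true",
    help="Prefill is round robin balanced. This is used to promise decode server "
    "can get the correct dp rank.",
)
parser.add_argument("--dp-size", type=int, default=1)
parser.add_argument("--disaggregation-mode", default="null")

args = parser.parse_args(
    ["--disaggregation-mode", "decode", "--dp-size", "2", "--prefill-round-robin-balance"]
)

# Same guard the patch adds for decode servers (here without the is_in_ci() escape hatch).
if args.disaggregation_mode == "decode" and args.dp_size > 1:
    assert args.prefill_round_robin_balance, (
        "Prefill round robin balance is required when dp size > 1. Launch the prefill "
        "instance with `--load-balance-method round_robin` and pass "
        "`--prefill-round-robin-balance` to the decode server."
    )
print(args.prefill_round_robin_balance)  # True
```

In practice (per the assertion message in the diff) the prefill instance is expected to run with --load-balance-method round_robin while the decode instance passes --prefill-round-robin-balance.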