change / sglang / Commits / d4d0c7c3

Commit d4d0c7c3 (unverified), authored Jul 15, 2025 by ykcombat; committed by GitHub on Jul 15, 2025.

[Feature] TP Group Switching for PD-Multiplexing (#7653)

Parent: 8d2cf38c
Changes: 3 changed files with 49 additions and 0 deletions.

  python/sglang/srt/distributed/parallel_state.py    +33  -0
  python/sglang/srt/model_executor/model_runner.py   +1   -0
  python/sglang/srt/server_args.py                   +15  -0
python/sglang/srt/distributed/parallel_state.py
@@ -1065,8 +1065,23 @@ def init_model_parallel_group(
 _TP: Optional[GroupCoordinator] = None
 
+# duplicate GroupCoordinator for prefill in PD-Multiplexing
+_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None
+
+_ENABLE_PDMUX_P_TP: bool = False
+
+
+def set_pdmux_status(enable_prefill_multiplexing: bool):
+    global _ENABLE_PDMUX_P_TP
+    _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing
+
 
 def get_tp_group() -> GroupCoordinator:
+    if _ENABLE_PDMUX_P_TP:
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is not None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is not initialized"
+        return _PDMUX_PREFILL_TP_GROUP
     assert _TP is not None, "tensor model parallel group is not initialized"
     return _TP
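Note: the switch is a process-wide flag. After set_pdmux_status(True), every caller of get_tp_group() is transparently handed the duplicated prefill group; after set_pdmux_status(False), callers get the default _TP group again. A hypothetical usage sketch (the prefill/decode steps are illustrative, not part of this commit):

    from sglang.srt.distributed.parallel_state import get_tp_group, set_pdmux_status

    set_pdmux_status(True)        # route collectives to _PDMUX_PREFILL_TP_GROUP
    prefill_tp = get_tp_group()   # duplicated group, used while prefill runs
    # ... run a prefill batch (illustrative) ...
    set_pdmux_status(False)       # back to the default _TP group
    decode_tp = get_tp_group()    # default group, used while decode runs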
@@ -1182,6 +1197,7 @@ def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     backend: Optional[str] = None,
+    duplicate_tp_group: bool = False,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -1239,6 +1255,23 @@ def initialize_model_parallel(
             group_name="tp",
         )
 
+    if duplicate_tp_group:
+        global _PDMUX_PREFILL_TP_GROUP
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
+        _PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            use_message_queue_broadcaster=get_bool_env_var(
+                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
+            ),
+            group_name="pdmux_prefill_tp",
+        )
+        _TP.pynccl_comm.disabled = False
+        _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
     global _PP
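The duplicate group covers the same ranks as _TP but owns its own communicators (note that both pynccl_comm handles are explicitly re-enabled above), so prefill and decode collectives do not serialize on a single NCCL communicator. A minimal sketch of the underlying pattern with plain torch.distributed, assuming the process group is already initialized (names here are illustrative, not sglang API):

    import torch.distributed as dist

    def make_duplicate_tp_groups(tp_ranks):
        # Two process groups over the SAME ranks yield two independent NCCL
        # communicators, so collectives on one can overlap with the other.
        tp_group = dist.new_group(tp_ranks)
        pdmux_prefill_tp_group = dist.new_group(tp_ranks)
        return tp_group, pdmux_prefill_tp_group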
python/sglang/srt/model_executor/model_runner.py
@@ -539,6 +539,7 @@ class ModelRunner:
         initialize_model_parallel(
             tensor_model_parallel_size=self.tp_size,
             pipeline_model_parallel_size=self.pp_size,
+            duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
             enable_dp_attention=self.server_args.enable_dp_attention,
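Since duplicate_tp_group defaults to False, behavior is unchanged unless the server is launched with --enable-pdmux. With the flag set, the call above effectively expands to (values for tp_size and pp_size are examples only):

    initialize_model_parallel(
        tensor_model_parallel_size=4,    # example: self.tp_size == 4
        pipeline_model_parallel_size=1,  # example: self.pp_size == 1
        duplicate_tp_group=True,         # self.server_args.enable_pdmux
    )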
python/sglang/srt/server_args.py
@@ -251,6 +251,10 @@ class ServerArgs:
     custom_weight_loader: Optional[List[str]] = None
     weight_loader_disable_mmap: bool = False
 
+    # For PD-Multiplexing
+    enable_pdmux: bool = False
+    sm_group_num: int = 3
+
     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:
@@ -1721,6 +1725,17 @@ class ServerArgs:
             default=None,
             help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
         )
+        parser.add_argument(
+            "--enable-pdmux",
+            action="store_true",
+            help="Enable PD-Multiplexing, PD running on greenctx stream.",
+        )
+        parser.add_argument(
+            "--sm-group-num",
+            type=int,
+            default=ServerArgs.sm_group_num,
+            help="Number of sm partition groups.",
+        )
         parser.add_argument(
             "--weight-loader-disable-mmap",
             action="store_true",
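With the defaults above, the feature is opt-in from the command line. A hypothetical launch using sglang's standard entry point (model path and TP size are placeholders, not prescribed by this commit):

    python -m sglang.launch_server \
        --model-path meta-llama/Llama-3.1-8B-Instruct \
        --tp-size 4 \
        --enable-pdmux \
        --sm-group-num 3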