sglang commit c4e81e64 (unverified)
Authored Oct 20, 2025 by ykcombat; committed by GitHub on Oct 20, 2025
[Feature] Use current greenctx stream to communicate in PD-Multiplexing. (#11594)
Parent: c726d44c

Showing 2 changed files with 48 additions and 20 deletions (+48 -20):

python/sglang/srt/distributed/device_communicators/pynccl.py   +24 -12
python/sglang/srt/distributed/parallel_state.py                +24 -8
python/sglang/srt/distributed/device_communicators/pynccl.py
@@ -30,6 +30,7 @@ class PyNcclCommunicator:
         group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
+        use_current_stream: bool = False,
     ):
         """
         Args:
@@ -74,6 +75,7 @@ class PyNcclCommunicator:
         self.available = True
         self.disabled = False
+        self.use_current_stream = use_current_stream
         self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
@@ -123,6 +125,21 @@ class PyNcclCommunicator:
             # when we are using CUDA graph.
             self.disabled = True
 
+    def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
+        """Return the stream to use for NCCL calls.
+
+        Behavior mirrors the previous inline logic:
+        - if an explicit stream is provided, return it
+        - if stream is None and self.use_current_stream is True, return
+          torch.cuda.current_stream()
+        - otherwise return the communicator's default stream (self.stream)
+        """
+        if stream is not None:
+            return stream
+        if self.use_current_stream:
+            return torch.cuda.current_stream()
+        return self.stream
+
     def all_reduce(
         self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
     ):
@@ -135,8 +152,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclAllReduce(
             buffer_type(tensor.data_ptr()),
             buffer_type(tensor.data_ptr()),
@@ -163,8 +179,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {input_tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         if sizes is not None:
             split_offset = 0
@@ -210,8 +225,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {input_tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         if sizes is not None:
             split_offset = 0
@@ -249,8 +263,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclSend(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),
@@ -267,8 +280,7 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclRecv(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),
@@ -285,8 +297,8 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {tensor.device}"
             )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
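A minimal usage sketch (not part of the commit) of how the new flag changes stream selection. Only PyNcclCommunicator, use_current_stream, and all_reduce come from the diff above; the process-group setup and variable names are assumptions for illustration.

import torch
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator

# Assumption: `group` is an already-initialized ProcessGroup or
# StatelessProcessGroup for this rank, and the local GPU is device 0.
comm = PyNcclCommunicator(group, device=0, use_current_stream=True)

# e.g. a per-phase stream in PD-multiplexing; an ordinary CUDA stream is
# used here purely for illustration.
compute_stream = torch.cuda.Stream()

with torch.cuda.stream(compute_stream):
    t = torch.ones(16, device="cuda")
    # With stream=None and use_current_stream=True, _resolve_stream() returns
    # torch.cuda.current_stream(), i.e. compute_stream, instead of the
    # communicator's internal self.stream.
    comm.all_reduce(t)

# An explicitly passed stream still takes precedence, regardless of the flag.
comm.all_reduce(t, stream=compute_stream)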
python/sglang/srt/distributed/parallel_state.py
@@ -239,6 +239,7 @@ class GroupCoordinator:
         use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
+        pynccl_use_current_stream: bool = False,
         torch_compile: Optional[bool] = None,
         gloo_timeout: timedelta = timedelta(seconds=120 * 60),
     ):
@@ -289,6 +290,7 @@ class GroupCoordinator:
         # Import communicators
         self.use_pynccl = use_pynccl
+        self.pynccl_use_current_stream = pynccl_use_current_stream
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
         self.use_torch_symm_mem = use_torch_symm_mem
@@ -322,6 +324,7 @@ class GroupCoordinator:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
+                use_current_stream=pynccl_use_current_stream,
             )
         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
@@ -449,10 +452,13 @@ class GroupCoordinator:
     @contextmanager
     def graph_capture(
-        self, graph_capture_context: Optional[GraphCaptureContext] = None
+        self,
+        graph_capture_context: Optional[GraphCaptureContext] = None,
+        stream: Optional[torch.cuda.Stream] = None,
     ):
         if graph_capture_context is None:
-            stream = self.device_module.Stream()
+            if stream is None:
+                stream = self.device_module.Stream()
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
@@ -1278,6 +1284,7 @@ def init_model_parallel_group(
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
+    pynccl_use_current_stream: bool = True,
     use_symm_mem_allreduce: Optional[bool] = None,
     torch_compile: Optional[bool] = None,
 ) -> GroupCoordinator:
@@ -1300,6 +1307,7 @@ def init_model_parallel_group(
         use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
+        pynccl_use_current_stream=pynccl_use_current_stream,
         torch_compile=torch_compile,
     )
@@ -1357,7 +1365,7 @@ get_pipeline_model_parallel_group = get_pp_group
 @contextmanager
-def graph_capture():
+def graph_capture(stream: Optional[torch.cuda.Stream] = None):
     """
     `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the
@@ -1371,9 +1379,9 @@ def graph_capture():
     in order to explicitly distinguish the kernels to capture
     from other kernels possibly launched on background in the default stream.
     """
-    with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(
-        context
-    ):
+    with get_tp_group().graph_capture(
+        stream=stream
+    ) as context, get_pp_group().graph_capture(context):
         yield context
@@ -1527,6 +1535,7 @@ def initialize_model_parallel(
             "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
         ),
         group_name="tp",
+        pynccl_use_current_stream=duplicate_tp_group,
         torch_compile=torch_compile,
     )
@@ -1543,10 +1552,12 @@ def initialize_model_parallel(
                 "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
             ),
             group_name="pdmux_prefill_tp",
+            pynccl_use_current_stream=True,
             torch_compile=torch_compile,
         )
-        _TP.pynccl_comm.disabled = False
-        _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+        if _TP.pynccl_comm:
+            _TP.pynccl_comm.disabled = False
+            _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
 
     moe_ep_size = expert_model_parallel_size
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
@@ -1737,6 +1748,11 @@ def destroy_model_parallel():
         _PP.destroy()
     _PP = None
 
+    global _PDMUX_PREFILL_TP_GROUP
+    if _PDMUX_PREFILL_TP_GROUP:  # type: ignore[union-attr]
+        _PDMUX_PREFILL_TP_GROUP.destroy()
+    _PDMUX_PREFILL_TP_GROUP = None
+
 
 def destroy_distributed_environment():
     global _WORLD
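The parallel_state.py change lets the caller supply the stream used for graph capture instead of always allocating a fresh one. A minimal sketch, assuming a model-parallel environment has already been initialized (e.g. via initialize_model_parallel()); only graph_capture and its new stream parameter come from the diff above, everything else is illustrative.

import torch
from sglang.srt.distributed.parallel_state import graph_capture

# e.g. a greenctx-backed stream owned by the PD-multiplexing scheduler;
# an ordinary CUDA stream is used here purely for illustration.
capture_stream = torch.cuda.Stream()

# Previously graph_capture() always created its own stream; with this change
# the caller's stream is wrapped into the TP group's GraphCaptureContext.
with graph_capture(stream=capture_stream) as ctx:
    # ctx.stream is capture_stream here, so the CUDA-graph capture work and
    # the PyNccl collectives issued during capture can share that stream.
    ...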