zhaoyu6 / sglang · Commits

Unverified commit 097725bb, authored Oct 02, 2025 by Lianmin Zheng, committed by GitHub on Oct 02, 2025
Parent: 44b1fbe2

Clean up parallel_state.py (#11148)
Showing 1 changed file with 99 additions and 80 deletions.
python/sglang/srt/distributed/parallel_state.py (+99, -80)
@@ -4,7 +4,7 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """
-vLLM distributed state.
+Distributed state.
 It takes over the control of the distributed environment from PyTorch.
 The typical workflow is:
@@ -53,19 +53,26 @@ from sglang.srt.utils import (
 _is_npu = is_npu()
 _is_cpu = is_cpu()
+_supports_custom_op = supports_custom_op()
 
 IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
 
-TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
-
-# use int value instead of ReduceOp.SUM to support torch compile
-REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
-
 
 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
 
 
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+
+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+
+
+@dataclass
+class P2PWork:
+    work: Optional[torch.distributed.Work]
+    payload: Optional[torch.Tensor]
+
+
 def _split_tensor_dict(
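Aside (not part of the diff): a minimal sketch of how a caller might drain the P2PWork handles introduced above, assuming the P2PWork dataclass from this file is in scope and the process group is initialized. The helper name wait_p2p_works is hypothetical.

from typing import List

def wait_p2p_works(p2p_works: List[P2PWork]) -> None:
    # Block until every pending isend/irecv has completed. Holding the
    # P2PWork.payload reference until then keeps the tensor buffer alive
    # while the transfer is still in flight.
    for p2p in p2p_works:
        if p2p.work is not None:
            p2p.work.wait()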
@@ -117,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
 _groups[group.unique_name] = weakref.ref(group)
 
 
-if supports_custom_op():
+if _supports_custom_op:
 
     def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         assert group_name in _groups, f"Group {group_name} is not found."
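Aside (not part of the diff): the commit replaces repeated supports_custom_op() calls with the module-level _supports_custom_op computed once at import time. A rough stand-in for what such a capability probe can look like; the hasattr check is an assumption for illustration, not necessarily sglang's exact implementation.

import torch

def supports_custom_op() -> bool:
    # torch.library.custom_op only exists on newer PyTorch releases; probing
    # once at import time avoids re-checking on every collective call.
    return hasattr(torch.library, "custom_op")

_supports_custom_op = supports_custom_op()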
@@ -277,7 +284,7 @@ class GroupCoordinator:
         self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
-        # lazy import to avoid documentation build error
+        # Lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
@@ -497,7 +504,7 @@ class GroupCoordinator:
             torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
-        if not supports_custom_op():
+        if not _supports_custom_op:
             self._all_reduce_in_place(input_)
             return input_
@@ -523,23 +530,24 @@ class GroupCoordinator:
         outplace_all_reduce_method = None
         if (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
+        elif (
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
         ):
             outplace_all_reduce_method = "ca"
         elif (
-            self.qr_comm is not None
-            and not self.qr_comm.disabled
-            and self.qr_comm.should_quick_allreduce(input_)
-        ):
-            outplace_all_reduce_method = "qr"
-        elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
             outplace_all_reduce_method = "pymscclpp"
 
         if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
                 input_,
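Aside (not part of the diff): after the reordering above, the out-of-place all-reduce backend is selected with quick all-reduce ("qr") tried first, then custom all-reduce ("ca"), then pymscclpp, before falling back to torch.distributed. A standalone sketch of that precedence; the function and flag names are hypothetical.

def pick_outplace_all_reduce(qr_ok: bool, ca_ok: bool, mscclpp_ok: bool) -> str:
    # Mirrors the elif chain above: the first eligible backend wins.
    if qr_ok:
        return "qr"
    if ca_ok:
        return "ca"
    if mscclpp_ok:
        return "pymscclpp"
    return "torch"  # fall back to torch.distributed.all_reduce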
@@ -553,16 +561,16 @@ class GroupCoordinator:
     def _all_reduce_out_place(
         self, input_: torch.Tensor, outplace_all_reduce_method: str
     ) -> torch.Tensor:
+        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
-        qr_comm = self.qr_comm
         pymscclpp_comm = self.pymscclpp_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
-        if outplace_all_reduce_method == "ca":
+        if outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
+        elif outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
-        elif outplace_all_reduce_method == "qr":
-            assert not qr_comm.disabled
-            out = qr_comm.quick_all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
@@ -637,7 +645,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if _is_npu or not supports_custom_op():
+        if _is_npu or not _supports_custom_op:
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -697,15 +705,13 @@ class GroupCoordinator:
         )
 
         # All-gather.
+        if input_.is_cpu and is_shm_available(
+            input_.dtype, self.world_size, self.local_size
+        ):
+            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+
         if input_.is_cpu:
-            if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-            else:
-                torch.distributed.all_gather_into_tensor(
-                    output_tensor, input_, group=self.device_group
-                )
+            torch.distributed.all_gather_into_tensor(
+                output_tensor, input_, group=self.device_group
+            )
         else:
             self.all_gather_into_tensor(output_tensor, input_)
@@ -861,45 +867,63 @@ class GroupCoordinator:
         torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
         return objs
 
-    def send_object(self, obj: Any, dst: int) -> None:
-        """Send the input object list to the destination rank."""
-        """NOTE: `dst` is the local rank of the destination rank."""
+    def send_object(
+        self,
+        obj: Any,
+        dst: int,
+        async_send: bool = False,
+    ) -> List[P2PWork]:
+        """
+        Send the input object list to the destination rank.
+        This function uses the CPU group for all communications.
 
+        TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
+        use other functions (e.g., send), or implement a new function (e.g., send_object_device).
+
+        NOTE: `dst` is the local rank of the destination rank.
+        """
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
         assert dst != self.rank_in_group, (
             "Invalid destination rank. Destination rank is the same "
             "as the current rank."
         )
 
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
+
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
-            device=torch.cuda.current_device()
-        )
-
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
         size_tensor = torch.tensor(
-            [object_tensor.numel()],
-            dtype=torch.long,
-            device="cpu",
+            [object_tensor.numel()], dtype=torch.long, device="cpu"
         )
 
         # Send object size
-        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        p2p_work = []
+        size_work = send_func(
+            size_tensor,
+            self.ranks[dst],
+            group=self.cpu_group,
+        )
+        if async_send:
+            p2p_work.append(P2PWork(size_work, size_tensor))
 
         # Send object
-        torch.distributed.send(
+        object_work = send_func(
             object_tensor,
-            dst=self.ranks[dst],
-            group=self.device_group,
+            self.ranks[dst],
+            group=self.cpu_group,
         )
+        if async_send:
+            p2p_work.append(P2PWork(object_work, object_tensor))
 
-        return None
+        return p2p_work
 
-    def recv_object(self, src: int) -> Any:
+    def recv_object(
+        self,
+        src: int,
+    ) -> Any:
         """Receive the input object list from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
         assert src < self.world_size, f"Invalid src rank ({src})"
         assert (
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
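Aside (not part of the diff): a hypothetical use of the new async_send path, assuming two ranks, an initialized GroupCoordinator bound to the name group, and that rank 1 calls group.recv_object(src=0) on its side.

# Rank 0: send a picklable object over the CPU group without blocking.
p2p_works = group.send_object({"step": 1, "done": False}, dst=1, async_send=True)
for p2p in p2p_works:
    if p2p.work is not None:
        p2p.work.wait()  # isend handles must complete before dropping the payloads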
@@ -907,27 +931,25 @@ class GroupCoordinator:
         size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
 
         # Receive object size
-        rank_size = torch.distributed.recv(
+        # We have to use irecv here to make it work for both isend and send.
+        work = torch.distributed.irecv(
             size_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()
 
         # Tensor to receive serialized objects into.
-        object_tensor = torch.empty(  # type: ignore[call-overload]
+        object_tensor: Any = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=torch.cuda.current_device(),
+            device="cpu",
         )
 
-        rank_object = torch.distributed.recv(
+        work = torch.distributed.irecv(
             object_tensor,
             src=self.ranks[src],
-            group=self.device_group,
+            group=self.cpu_group,
         )
+        work.wait()
 
-        assert (
-            rank_object == rank_size
-        ), "Received object sender rank does not match the size sender rank."
-
-        obj = pickle.loads(object_tensor.cpu().numpy())
+        obj = pickle.loads(object_tensor.numpy())
 
         return obj
 
     def broadcast_tensor_dict(
@@ -1017,12 +1039,13 @@ class GroupCoordinator:
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        async_send: bool = False,
+    ) -> Optional[List[P2PWork]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
         # Bypass the function if we are using only 1 GPU.
-        if not torch.distributed.is_initialized() or self.world_size == 1:
+        if self.world_size == 1:
             return tensor_dict
 
         all_gather_size = (
             1 if all_gather_group is None else all_gather_group.world_size
         )
@@ -1047,7 +1070,10 @@ class GroupCoordinator:
         # 1. Superior D2D transfer bandwidth
         # 2. Ability to overlap send and recv operations
         # Thus the net performance gain justifies this approach.
-        self.send_object(metadata_list, dst=dst)
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
+        p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
 
         for tensor in tensor_list:
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
@@ -1057,15 +1083,10 @@ class GroupCoordinator:
             if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.send(
-                    tensor, dst=self.ranks[dst], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
-        return None
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = send_func(tensor, self.ranks[dst], group=comm_group)
+            p2p_works.append(P2PWork(work, tensor))
+        return p2p_works
 
     def recv_tensor_dict(
         self,
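Aside (not part of the diff): a sketch of the async tensor-dict hand-off on the sending rank, assuming group is an initialized GroupCoordinator, CUDA tensors, and a matching recv_tensor_dict() call on rank 1.

import torch

p2p_works = group.send_tensor_dict(
    {"hidden_states": torch.randn(8, 16, device="cuda")},
    dst=1,
    async_send=True,
)
# With async_send=True, completion is deferred to the caller.
if p2p_works is not None:
    for p2p in p2p_works:
        if p2p.work is not None:
            p2p.work.wait()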
@@ -1111,17 +1132,15 @@ class GroupCoordinator:
                 orig_shape = tensor.shape
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.recv(
-                    tensor, src=self.ranks[src], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.recv(tensor, src=self.ranks[src], group=group)
+            # We have to use irecv here to make it work for both isend and send.
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = torch.distributed.irecv(
+                tensor, src=self.ranks[src], group=comm_group
+            )
+            work.wait()
+
             if use_all_gather:
                 # do the allgather
-                tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore
+                tensor = all_gather_group.all_gather(tensor, dim=0)
                 tensor = tensor.reshape(orig_shape)
             tensor_dict[key] = tensor