sglang · Commit 8609e637 (unverified)
Authored Jun 20, 2025 by Cheng Wan; committed by GitHub on Jun 20, 2025
Parent: dea2b84b

Fix All-Gather under world size one (#7219)

Changes: 1 changed file with 12 additions and 4 deletions

python/sglang/srt/distributed/parallel_state.py (+12 −4)
```diff
@@ -523,17 +523,25 @@ class GroupCoordinator:
         self,
         input_: torch.Tensor,
         dim: int = -1,
-        tensor_list: List[torch.Tensor] = None,
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-            return input_
-        if tensor_list is not None:
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-                tensor_list, input_, group=self.device_group
+                output_tensor_list, input_, group=self.device_group
             )
         assert (
```
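The point of the fix: under a world size of 1, an all-gather is an identity operation, but a caller that passes `output_tensor_list` still expects its pre-allocated buffers to be filled; simply returning `input_` (the old behavior) would leave them untouched. A minimal sketch of the corrected control flow, using plain Python lists in place of tensors (hypothetical standalone function, not the actual sglang `GroupCoordinator` API):

```python
from typing import List, Optional

def all_gather(input_: list, world_size: int,
               output_tensor_list: Optional[List[list]] = None):
    """Sketch of the world-size-1 bypass from the patch above."""
    # Bypass collective communication when there is only one rank.
    if world_size == 1:
        if output_tensor_list is not None:
            # Mimic tensor.copy_(): fill the caller's buffer in place,
            # so the output list holds the (single) rank's data.
            output_tensor_list[0][:] = input_
            return None
        # No output buffer: the gathered result is just the input.
        return input_
    # The multi-rank path would delegate to torch.distributed.all_gather.
    raise NotImplementedError("multi-rank path needs a process group")

# Usage: the pre-allocated output buffer is populated even with one rank.
buf = [[0, 0, 0]]
all_gather([1, 2, 3], world_size=1, output_tensor_list=buf)
# buf[0] is now [1, 2, 3]
```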