sglang · commit 8fcc69e7 (unverified)
Authored Oct 18, 2025 by Chunyuan WU; committed via GitHub on Oct 17, 2025
Parent: f440baa1

Turn on shm_allreduce and shm_allgather for fp16 (#10725)

Showing 2 changed files with 119 additions and 1 deletion:
  python/sglang/srt/utils/common.py   +1    -1
  test/srt/cpu/test_comm.py           +118  -0
python/sglang/srt/utils/common.py

@@ -2921,7 +2921,7 @@ def get_cpu_ids_by_node():
 def is_shm_available(dtype, world_size, local_size):
     return (
         cpu_has_amx_support()
-        and dtype in [torch.bfloat16, torch.float]
+        and dtype in [torch.bfloat16, torch.float16, torch.float]
         and world_size >= 1
         and world_size == local_size
     )
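Note: is_shm_available is the gate that decides whether sglang's CPU shared-memory collectives may be used, and this hunk adds torch.float16 to the accepted dtypes. Below is a minimal sketch of how such a gate is typically consumed; the wrapper name allreduce_maybe_shm is hypothetical and not part of this commit, the import path is assumed from the file path shown above, and is_shm_available plus torch.ops.sgl_kernel.shm_allreduce are taken from the changed files.

import torch
import torch.distributed as dist

# Assumed import path, following python/sglang/srt/utils/common.py.
from sglang.srt.utils.common import is_shm_available


def allreduce_maybe_shm(tensor: torch.Tensor, world_size: int, local_size: int) -> torch.Tensor:
    # Hypothetical wrapper (not the actual sglang call site): take the CPU
    # shared-memory all-reduce when the gate allows it, otherwise fall back
    # to the regular torch.distributed collective.
    if is_shm_available(tensor.dtype, world_size, local_size):
        # After this commit, float16 tensors also pass the dtype check
        # (previously only bfloat16 and float32 did).
        torch.ops.sgl_kernel.shm_allreduce(tensor, dist.ReduceOp.SUM)
    else:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor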
test/srt/cpu/test_comm.py (new file, mode 100644)

import copy
import multiprocessing
import os
import traceback
import unittest
from multiprocessing import Process

import sgl_kernel
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from utils import precision

from sglang.test.test_utils import CustomTestCase, find_available_port


def run_distributed_test(rank, world_size, master_port, output_writer, fn):
    try:
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["LOCAL_SIZE"] = str(world_size)
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        torch.ops.sgl_kernel.initialize(world_size, rank)

        fn(rank, world_size)

        execution_ok = True
    except Exception as e:
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()

    if dist.is_initialized():
        dist.destroy_process_group()


def all_reduce_fn(rank, world_size):
    op = dist.ReduceOp.SUM
    for dtype in [torch.float32, torch.bfloat16, torch.float16]:
        tensor = torch.randn(2, 10, dtype=dtype)
        tensor_shm = copy.deepcopy(tensor)
        dist.all_reduce(tensor, op=op)
        torch.ops.sgl_kernel.shm_allreduce(tensor_shm, op)
        torch.testing.assert_close(tensor, tensor_shm)


def all_gather_fn(rank, world_size):
    dim = -1
    for dtype in [torch.float32, torch.bfloat16, torch.float16]:
        tensor = torch.randn(2, 10, dtype=dtype)
        if dim < 0:
            # Convert negative dim to positive.
            dim += tensor.dim()

        input_size = tensor.size()
        output_size = (input_size[0] * world_size,) + input_size[1:]
        output_tensor = torch.empty(output_size, dtype=tensor.dtype, device=tensor.device)
        dist.all_gather_into_tensor(output_tensor, tensor)
        output_tensor = output_tensor.reshape((world_size,) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )

        output_shm = torch.ops.sgl_kernel.shm_allgather(tensor, dim)
        torch.testing.assert_close(output_tensor, output_shm)


class TestComm(CustomTestCase):
    def _spawn_and_check(self, fn, world_size=2):
        mp.set_start_method("spawn", force=True)
        master_port = find_available_port(23456)

        processes = []
        output_reader, output_writer = multiprocessing.Pipe(duplex=False)
        for rank in range(world_size):
            p = Process(
                target=run_distributed_test,
                kwargs=dict(
                    rank=rank,
                    world_size=world_size,
                    master_port=master_port,
                    output_writer=output_writer,
                    fn=fn,
                ),
            )
            p.start()
            processes.append(p)

        for _ in range(world_size):
            self.assertTrue(output_reader.recv(), "Subprocess fail. Check logs above.")

        for p in processes:
            p.join()

    def test_all_reduce(self):
        self._spawn_and_check(all_reduce_fn)

    def test_all_gather(self):
        self._spawn_and_check(all_gather_fn)


if __name__ == "__main__":
    unittest.main()
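A brief note on running this new test, based only on what the file itself contains: executing it directly (python test/srt/cpu/test_comm.py) invokes unittest.main(), which for each test spawns world_size=2 subprocesses, initializes a gloo process group plus the sgl_kernel shared-memory backend in each, and compares shm_allreduce / shm_allgather results against the stock torch.distributed collectives for float32, bfloat16, and the newly enabled float16. It assumes the sgl_kernel CPU extension and the sibling utils module (which provides precision) are importable.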