sglang · Commit 5ec5eaf7 (unverified)
Authored Mar 30, 2025 by Yi Zhang; committed by GitHub Mar 29, 2025

fix allreduce test (#4909)

Parent: 0d7fe866
Showing 1 changed file with 14 additions and 91 deletions:
  sgl-kernel/tests/test_trt_allreduce.py (+14, -91)

sgl-kernel/tests/test_trt_allreduce.py (view file @ 5ec5eaf7)
import ctypes
import logging
import multiprocessing as mp
import random
import socket
import time
import unittest
from typing import Any, List, Optional

import ray
import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

logger = logging.getLogger(__name__)


def get_open_port() -> int:
    # try ipv4
    ...
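The body of get_open_port is collapsed in this view; only the "# try ipv4" hint is visible. Purely as an illustration (not the file's actual code), an ipv4-first free-port helper usually binds port 0 and reads back whatever port the OS assigned:

# Illustrative sketch only -- the real get_open_port body is collapsed above.
import socket


def get_open_port_sketch() -> int:
    # try ipv4 first, fall back to ipv6
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))  # port 0 -> the OS picks a free port
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]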
@@ -33,22 +28,21 @@ def get_open_port() -> int:
def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=True)
    procs = []
    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
    ray.get(refs)
    for i in range(world_size):
        proc = mp.Process(
            target=test_target,
            args=(world_size, i, distributed_init_port),
        )
        proc.start()
        procs.append(proc)
    ray.shutdown()
    for i in range(world_size):
        procs[i].join()
        assert procs[i].exitcode == 0
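The hunk above contains both a ray-driven runner and an mp.Process-based one. One practical note on the multiprocessing variant (a sketch with made-up names, not part of this commit): if the parent process has already touched CUDA, forked workers cannot re-initialize it, so runners like this are often built on a "spawn" context and assert each worker's exit code in the parent:

# Sketch only -- not from the diff.
import multiprocessing as mp


def run_workers_sketch(world_size, test_target, distributed_init_port):
    ctx = mp.get_context("spawn")  # avoids CUDA re-init issues in forked children
    procs = []
    for rank in range(world_size):
        proc = ctx.Process(
            target=test_target,
            args=(world_size, rank, distributed_init_port),
        )
        proc.start()
        procs.append(proc)
    for proc in procs:
        proc.join()
        assert proc.exitcode == 0  # surface per-rank failures in the parent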
class TestCustomAllReduce(unittest.TestCase):
    ...
@@ -95,13 +89,8 @@ class TestCustomAllReduce(unittest.TestCase):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.correctness)

    def test_performance(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.performance)
            multi_process_parallel(world_size, self.correctness)
            print(f"custom allreduce tp = {world_size}: OK")
    def init_custom_allreduce(self, rank, world_size, group):
        buffer_max_size = 8 * 1024 * 1024
        ...
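init_custom_allreduce sizes an 8 MiB buffer and, in the collapsed lines, registers cross-process buffers through self.create_shared_buffer. A rough sketch of how such a helper is commonly built on CudaRTLibrary follows; the wrapper method names and the all_gather_object exchange are assumptions, not the file's actual implementation:

# Sketch only: share per-rank CUDA buffers via cudaMalloc + CUDA IPC handles.
from typing import List, Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


def create_shared_buffer_sketch(
    size_in_bytes: int, group: Optional[ProcessGroup] = None
) -> List[int]:
    lib = CudaRTLibrary()
    pointer = lib.cudaMalloc(size_in_bytes)    # local device allocation (assumed wrapper)
    handle = lib.cudaIpcGetMemHandle(pointer)  # handle other ranks can open (assumed wrapper)
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=group)  # exchange handles across ranks
    pointers: List[int] = []
    for i, h in enumerate(handles):
        if i == rank:
            pointers.append(pointer.value)                      # our own buffer
        else:
            pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # peer buffer mapped locally
    return pointers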
@@ -137,37 +126,6 @@ class TestCustomAllReduce(unittest.TestCase):
        self.free_shared_buffer(self.barrier_out_ptrs, group)
        custom_ops.custom_dispose(self.custom_ptr)

    def init_vllm_allreduce(self, rank, group):
        self.vllm_rank = rank
        self.vllm_max_size = 8 * 1024 * 1024
        self.vllm_meta_ptrs = self.create_shared_buffer(
            vllm_ops.meta_size() + self.vllm_max_size, group=group
        )
        self.vllm_buffer_ptrs = self.create_shared_buffer(self.vllm_max_size, group=group)
        self.vllm_rank_data = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
        )
        self.vllm_ptr = vllm_ops.init_custom_ar(
            self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
        )
        vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)

    def vllm_allreduce(self, inp, out):
        vllm_ops.all_reduce(
            self.vllm_ptr,
            inp,
            out,
            self.vllm_buffer_ptrs[self.vllm_rank],
            self.vllm_max_size,
        )

    def free_vllm_allreduce(self, group):
        vllm_ops.dispose(self.vllm_ptr)
        self.free_shared_buffer(self.vllm_meta_ptrs, group)
        self.free_shared_buffer(self.vllm_buffer_ptrs, group)

    @staticmethod
    def init_distributed_env(world_size, rank, distributed_init_port):
        device = torch.device("cuda:0")
        ...
@@ -184,7 +142,6 @@ class TestCustomAllReduce(unittest.TestCase):
        return group

    # compare result with torch.distributed
    @ray.remote(num_gpus=1, max_calls=1)
    def correctness(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)
        ...
@@ -205,40 +162,6 @@ class TestCustomAllReduce(unittest.TestCase):
        self.free_custom_allreduce(group)
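The body of correctness is collapsed between the two hunks above; per the "# compare result with torch.distributed" comment, it checks the custom kernel against a torch.distributed reference before the free_custom_allreduce call shown here. A minimal sketch of that comparison pattern (illustrative only; sizes, dtype, and the method name are assumptions):

    # Sketch only -- not the collapsed test body. Assumes `group` comes from
    # init_distributed_env and self.custom_allreduce(inp, out) writes into `out`.
    def correctness_check_sketch(self, group):
        for sz in self.test_sizes:
            inp = torch.randint(1, 16, (sz,), dtype=torch.float32, device="cuda")
            out = torch.empty_like(inp)
            self.custom_allreduce(inp, out)    # kernel under test
            ref = inp.clone()
            dist.all_reduce(ref, group=group)  # reference result from torch.distributed
            torch.testing.assert_close(out, ref)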
    # compare performance with vllm
    @ray.remote(num_gpus=1, max_calls=1)
    def performance(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)
        self.init_vllm_allreduce(rank, group)
        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
        for sz in self.test_sizes:
            inp1 = torch.randint(
                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
            )
            out1 = torch.empty_like(inp1)
            test_loop = 5000
            start = time.time()
            for _ in range(test_loop):
                self.custom_allreduce(inp1, out1)
            elapse_custom = time.time() - start
            start = time.time()
            for _ in range(test_loop):
                self.vllm_allreduce(inp1, out1)
            elapse_vllm = time.time() - start
            if rank == 0:
                logger.warning(
                    f"test_size = {sz}, world_size = {world_size}, "
                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
                    f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
                )
        self.free_custom_allreduce(group)
        self.free_vllm_allreduce(group)
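One caveat about the benchmark above: time.time() around asynchronous CUDA launches mostly measures enqueue cost unless the device is synchronized. A sketch of the usual synchronized variant (illustrative; allreduce_fn stands in for either implementation, not code from the commit):

# Sketch only -- not part of the commit.
def timed_loop_sketch(allreduce_fn, inp, out, iters=5000):
    torch.cuda.synchronize()  # make sure prior work is done
    start = time.time()
    for _ in range(iters):
        allreduce_fn(inp, out)
    torch.cuda.synchronize()  # wait for the timed kernels to finish
    return (time.time() - start) * 1000 / iters  # average ms per call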
if __name__ == "__main__":
    unittest.main()