zhaoyu6 / sglang · Commits · 5ec5eaf7

Unverified commit 5ec5eaf7, authored Mar 30, 2025 by Yi Zhang, committed by GitHub Mar 29, 2025
Parent: 0d7fe866

fix allreduce test (#4909)
Showing 1 changed file with 14 additions and 91 deletions

sgl-kernel/tests/test_trt_allreduce.py  (+14, -91)
 import ctypes
-import logging
+import multiprocessing as mp
 import random
 import socket
-import time
 import unittest
 from typing import Any, List, Optional
 
-import ray
 import sgl_kernel.allreduce as custom_ops
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
-from vllm import _custom_ops as vllm_ops
 
 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 
-logger = logging.getLogger(__name__)
-
 
 def get_open_port() -> int:
     # try ipv4
...
@@ -33,22 +28,21 @@ def get_open_port() -> int:
 
 def multi_process_parallel(
     world_size: int,
-    cls: Any,
     test_target: Any,
 ) -> None:
-    # Using ray helps debugging the error when it failed
-    # as compared to multiprocessing.
-    # NOTE: We need to set working_dir for distributed tests,
-    # otherwise we may get import errors on ray workers
-    ray.init(log_to_driver=True)
-
+    procs = []
     distributed_init_port = get_open_port()
-    refs = []
-    for rank in range(world_size):
-        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
-    ray.get(refs)
+    for i in range(world_size):
+        proc = mp.Process(
+            target=test_target,
+            args=(world_size, i, distributed_init_port),
+        )
+        proc.start()
+        procs.append(proc)
 
-    ray.shutdown()
+    for i in range(world_size):
+        procs[i].join()
+        assert procs[i].exitcode == 0
 
 
 class TestCustomAllReduce(unittest.TestCase):
...
@@ -95,13 +89,8 @@ class TestCustomAllReduce(unittest.TestCase):
         for world_size in self.world_sizes:
             if world_size > torch.cuda.device_count():
                 continue
-            multi_process_parallel(world_size, self, self.correctness)
+            multi_process_parallel(world_size, self.correctness)
+            print(f"custom allreduce tp = {world_size}: OK")
 
-    def test_performance(self):
-        for world_size in self.world_sizes:
-            if world_size > torch.cuda.device_count():
-                continue
-            multi_process_parallel(world_size, self, self.performance)
-
     def init_custom_allreduce(self, rank, world_size, group):
         buffer_max_size = 8 * 1024 * 1024
...
@@ -137,37 +126,6 @@ class TestCustomAllReduce(unittest.TestCase):
         self.free_shared_buffer(self.barrier_out_ptrs, group)
         custom_ops.custom_dispose(self.custom_ptr)
 
-    def init_vllm_allreduce(self, rank, group):
-        self.vllm_rank = rank
-        self.vllm_max_size = 8 * 1024 * 1024
-        self.vllm_meta_ptrs = self.create_shared_buffer(
-            vllm_ops.meta_size() + self.vllm_max_size, group=group
-        )
-        self.vllm_buffer_ptrs = self.create_shared_buffer(
-            self.vllm_max_size, group=group
-        )
-        self.vllm_rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
-        )
-        self.vllm_ptr = vllm_ops.init_custom_ar(
-            self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
-        )
-        vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)
-
-    def vllm_allreduce(self, inp, out):
-        vllm_ops.all_reduce(
-            self.vllm_ptr,
-            inp,
-            out,
-            self.vllm_buffer_ptrs[self.vllm_rank],
-            self.vllm_max_size,
-        )
-
-    def free_vllm_allreduce(self, group):
-        vllm_ops.dispose(self.vllm_ptr)
-        self.free_shared_buffer(self.vllm_meta_ptrs, group)
-        self.free_shared_buffer(self.vllm_buffer_ptrs, group)
-
     @staticmethod
     def init_distributed_env(world_size, rank, distributed_init_port):
         device = torch.device("cuda:0")
...
@@ -184,7 +142,6 @@ class TestCustomAllReduce(unittest.TestCase):
         return group
 
     # compare result with torch.distributed
-    @ray.remote(num_gpus=1, max_calls=1)
     def correctness(self, world_size, rank, distributed_init_port):
         group = self.init_distributed_env(world_size, rank, distributed_init_port)
...
@@ -205,40 +162,6 @@ class TestCustomAllReduce(unittest.TestCase):
         self.free_custom_allreduce(group)
 
-    # compare performance with vllm
-    @ray.remote(num_gpus=1, max_calls=1)
-    def performance(self, world_size, rank, distributed_init_port):
-        group = self.init_distributed_env(world_size, rank, distributed_init_port)
-        self.init_vllm_allreduce(rank, group)
-        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
-        for sz in self.test_sizes:
-            inp1 = torch.randint(
-                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
-            )
-            out1 = torch.empty_like(inp1)
-            test_loop = 5000
-            start = time.time()
-            for _ in range(test_loop):
-                self.custom_allreduce(inp1, out1)
-            elapse_custom = time.time() - start
-
-            start = time.time()
-            for _ in range(test_loop):
-                self.vllm_allreduce(inp1, out1)
-            elapse_vllm = time.time() - start
-
-            if rank == 0:
-                logger.warning(
-                    f"test_size = {sz}, world_size = {world_size}, "
-                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
-                    f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
-                )
-        self.free_custom_allreduce(group)
-        self.free_vllm_allreduce(group)
 
 
 if __name__ == "__main__":
     unittest.main()
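For context, the sketch below shows, in isolation, the multiprocessing launcher pattern that the updated multi_process_parallel uses: one process per rank, a shared rendezvous port from get_open_port, then a join plus an exit-code assertion for every process. It is a minimal illustration, not the test itself; the worker name _example_worker and the helper run_parallel are illustrative only, and the real test passes bound methods such as self.correctness that set up a torch.distributed process group on a GPU.

# Minimal, standalone sketch of the per-rank launcher pattern (illustrative names).
import multiprocessing as mp
import socket


def get_open_port() -> int:
    # Ask the OS for a free TCP port to use as the rendezvous port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _example_worker(world_size: int, rank: int, port: int) -> None:
    # Stand-in for a per-rank test body (e.g. an allreduce correctness check).
    print(f"rank {rank}/{world_size} would rendezvous on port {port}")


def run_parallel(world_size: int, test_target) -> None:
    port = get_open_port()
    procs = []
    for rank in range(world_size):
        proc = mp.Process(target=test_target, args=(world_size, rank, port))
        proc.start()
        procs.append(proc)
    for proc in procs:
        proc.join()
        # Any worker that crashed or raised reports a non-zero exit code,
        # which fails the launching test.
        assert proc.exitcode == 0


if __name__ == "__main__":
    run_parallel(2, _example_worker)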