Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e764e7b
Unverified
Commit
9e764e7b
authored
Jan 06, 2025
by
cennn
Committed by
GitHub
Jan 06, 2025
Browse files
[distributed] remove pynccl's redundant change_state (#11749)
parent
33fc1e2e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
62 deletions
+28
-62
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+27
-37
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+0
-17
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+1
-8
No files found.
tests/distributed/test_pynccl.py
View file @
9e764e7b
...
...
@@ -59,7 +59,6 @@ def worker_fn():
device
=
get_world_group
().
device
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
).
cuda
(
pynccl_comm
.
rank
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
tensor
=
pynccl_comm
.
all_reduce
(
tensor
)
torch
.
cuda
.
synchronize
()
assert
torch
.
all
(
tensor
==
pynccl_comm
.
world_size
).
cpu
().
item
()
...
...
@@ -81,7 +80,6 @@ def multiple_allreduce_worker_fn():
group
=
groups
[
0
]
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]
else
groups
[
1
]
pynccl_comm
=
PyNcclCommunicator
(
group
=
group
,
device
=
device
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
# two groups can communicate independently
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
tensor
=
pynccl_comm
.
all_reduce
(
tensor
)
...
...
@@ -137,8 +135,7 @@ def worker_fn_with_cudagraph():
# run something in the default stream to initialize torch engine
a
=
torch
.
ones
((
4
,
4
),
device
=
f
'cuda:
{
pynccl_comm
.
rank
}
'
)
torch
.
cuda
.
synchronize
()
with
torch
.
cuda
.
graph
(
graph
),
\
pynccl_comm
.
change_state
(
enable
=
True
):
with
torch
.
cuda
.
graph
(
graph
):
a_out
=
pynccl_comm
.
all_reduce
(
a
)
torch
.
cuda
.
synchronize
()
graph
.
replay
()
...
...
@@ -167,7 +164,6 @@ def all_gather_worker_fn():
for
r
in
range
(
world_size
)
]).
to
(
device
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
pynccl_comm
.
all_gather
(
result
,
tensor
)
torch
.
cuda
.
synchronize
()
torch
.
testing
.
assert_close
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-8
)
...
...
@@ -205,7 +201,6 @@ def reduce_scatter_worker_fn():
expected
=
sum
(
tensor
[
rank
*
scattered_size
:(
rank
+
1
)
*
scattered_size
]
for
tensor
in
all_tensors
).
to
(
device
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
pynccl_comm
.
reduce_scatter
(
result
,
tensor
)
torch
.
cuda
.
synchronize
()
torch
.
testing
.
assert_close
(
result
,
expected
,
rtol
=
1e-5
,
atol
=
1e-8
)
...
...
@@ -233,15 +228,13 @@ def send_recv_worker_fn():
else
:
tensor
=
torch
.
empty
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
).
cuda
(
pynccl_comm
.
rank
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
if
pynccl_comm
.
rank
==
0
:
pynccl_comm
.
send
(
tensor
,
dst
=
(
pynccl_comm
.
rank
+
1
)
%
pynccl_comm
.
world_size
)
dst
=
(
pynccl_comm
.
rank
+
1
)
%
pynccl_comm
.
world_size
)
else
:
pynccl_comm
.
recv
(
tensor
,
src
=
(
pynccl_comm
.
rank
-
1
)
%
pynccl_comm
.
world_size
)
src
=
(
pynccl_comm
.
rank
-
1
)
%
pynccl_comm
.
world_size
)
torch
.
cuda
.
synchronize
()
assert
torch
.
all
(
tensor
==
1
).
cpu
().
item
()
...
...
@@ -272,15 +265,12 @@ def multiple_send_recv_worker_fn():
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
with
pynccl_comm
.
change_state
(
enable
=
True
):
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
pynccl_comm
.
send
(
tensor
,
dst
=
(
pynccl_comm
.
rank
+
1
)
%
pynccl_comm
.
world_size
)
dst
=
(
pynccl_comm
.
rank
+
1
)
%
pynccl_comm
.
world_size
)
else
:
pynccl_comm
.
recv
(
tensor
,
src
=
(
pynccl_comm
.
rank
-
1
)
%
pynccl_comm
.
world_size
)
src
=
(
pynccl_comm
.
rank
-
1
)
%
pynccl_comm
.
world_size
)
torch
.
cuda
.
synchronize
()
if
torch
.
distributed
.
get_rank
()
in
[
0
,
2
]:
assert
torch
.
all
(
tensor
==
1
).
cpu
().
item
()
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
9e764e7b
from
contextlib
import
contextmanager
from
typing
import
Optional
,
Union
# ===================== import region =====================
...
...
@@ -213,19 +212,3 @@ class PyNcclCommunicator:
self
.
nccl
.
ncclBroadcast
(
sendbuff
,
recvbuff
,
tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
src
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
@
contextmanager
def
change_state
(
self
,
enable
:
Optional
[
bool
]
=
None
):
"""
A context manager to change the state of the communicator.
"""
if
enable
is
None
:
# guess a default value when not specified
enable
=
self
.
available
old_disable
=
self
.
disabled
self
.
disabled
=
not
enable
yield
self
.
disabled
=
old_disable
vllm/distributed/parallel_state.py
View file @
9e764e7b
...
...
@@ -305,13 +305,6 @@ class GroupCoordinator:
stream
.
wait_stream
(
curr_stream
)
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
pynccl_comm
=
self
.
pynccl_comm
maybe_pynccl_context
:
Any
if
not
pynccl_comm
:
maybe_pynccl_context
=
nullcontext
()
else
:
maybe_pynccl_context
=
pynccl_comm
.
change_state
()
with
maybe_pynccl_context
:
yield
graph_capture_context
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment