Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
635b8972
Unverified
Commit
635b8972
authored
Jan 05, 2025
by
cennn
Committed by
GitHub
Jan 05, 2025
Browse files
[distributed] remove pynccl's redundant stream (#11744)
parent
4068f4b5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
24 deletions
+12
-24
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+2
-3
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+9
-19
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+1
-2
No files found.
tests/distributed/test_pynccl.py
View file @
635b8972
...
...
@@ -137,9 +137,8 @@ def worker_fn_with_cudagraph():
# run something in the default stream to initialize torch engine
a
=
torch
.
ones
((
4
,
4
),
device
=
f
'cuda:
{
pynccl_comm
.
rank
}
'
)
torch
.
cuda
.
synchronize
()
with
torch
.
cuda
.
graph
(
graph
,
stream
=
pynccl_comm
.
stream
),
pynccl_comm
.
change_state
(
enable
=
True
):
with
torch
.
cuda
.
graph
(
graph
),
\
pynccl_comm
.
change_state
(
enable
=
True
):
a_out
=
pynccl_comm
.
all_reduce
(
a
)
torch
.
cuda
.
synchronize
()
graph
.
replay
()
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
635b8972
...
...
@@ -51,7 +51,6 @@ class PyNcclCommunicator:
if
self
.
world_size
==
1
:
self
.
available
=
False
self
.
disabled
=
True
self
.
stream
=
None
return
try
:
self
.
nccl
=
NCCLLibrary
(
library_path
)
...
...
@@ -60,7 +59,6 @@ class PyNcclCommunicator:
# e.g. in a non-GPU environment
self
.
available
=
False
self
.
disabled
=
True
self
.
stream
=
None
return
self
.
available
=
True
...
...
@@ -98,12 +96,12 @@ class PyNcclCommunicator:
with
torch
.
cuda
.
device
(
device
):
self
.
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
self
.
world_size
,
self
.
unique_id
,
self
.
rank
)
self
.
stream
=
torch
.
cuda
.
Stream
()
stream
=
torch
.
cuda
.
current_stream
()
# A small all_reduce for warmup.
data
=
torch
.
zeros
(
1
,
device
=
device
)
self
.
all_reduce
(
data
)
self
.
stream
.
synchronize
()
stream
.
synchronize
()
del
data
def
all_reduce
(
self
,
...
...
@@ -122,7 +120,7 @@ class PyNcclCommunicator:
out_tensor
=
torch
.
empty_like
(
in_tensor
)
if
stream
is
None
:
stream
=
self
.
stream
stream
=
torch
.
cuda
.
current_
stream
()
self
.
nccl
.
ncclAllReduce
(
buffer_type
(
in_tensor
.
data_ptr
()),
buffer_type
(
out_tensor
.
data_ptr
()),
in_tensor
.
numel
(),
...
...
@@ -144,7 +142,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
input_tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
stream
=
torch
.
cuda
.
current_
stream
()
self
.
nccl
.
ncclAllGather
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
input_tensor
.
numel
(),
...
...
@@ -165,7 +163,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
input_tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
stream
=
torch
.
cuda
.
current_
stream
()
self
.
nccl
.
ncclReduceScatter
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
output_tensor
.
numel
(),
...
...
@@ -180,7 +178,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
stream
=
torch
.
cuda
.
current_
stream
()
self
.
nccl
.
ncclSend
(
buffer_type
(
tensor
.
data_ptr
()),
tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
dst
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
...
...
@@ -192,7 +190,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
stream
=
torch
.
cuda
.
current_
stream
()
self
.
nccl
.
ncclRecv
(
buffer_type
(
tensor
.
data_ptr
()),
tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
src
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
...
...
@@ -204,7 +202,7 @@ class PyNcclCommunicator:
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
stream
=
torch
.
cuda
.
current_
stream
()
if
src
==
self
.
rank
:
sendbuff
=
buffer_type
(
tensor
.
data_ptr
())
# NCCL requires the sender also to have a receive buffer
...
...
@@ -217,9 +215,7 @@ class PyNcclCommunicator:
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
))
@
contextmanager
def
change_state
(
self
,
enable
:
Optional
[
bool
]
=
None
,
stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
):
def
change_state
(
self
,
enable
:
Optional
[
bool
]
=
None
):
"""
A context manager to change the state of the communicator.
"""
...
...
@@ -227,15 +223,9 @@ class PyNcclCommunicator:
# guess a default value when not specified
enable
=
self
.
available
if
stream
is
None
:
stream
=
self
.
stream
old_disable
=
self
.
disabled
old_stream
=
self
.
stream
self
.
stream
=
stream
self
.
disabled
=
not
enable
yield
self
.
disabled
=
old_disable
self
.
stream
=
old_stream
vllm/distributed/parallel_state.py
View file @
635b8972
...
...
@@ -310,8 +310,7 @@ class GroupCoordinator:
if
not
pynccl_comm
:
maybe_pynccl_context
=
nullcontext
()
else
:
maybe_pynccl_context
=
pynccl_comm
.
change_state
(
stream
=
torch
.
cuda
.
current_stream
())
maybe_pynccl_context
=
pynccl_comm
.
change_state
()
with
maybe_pynccl_context
:
yield
graph_capture_context
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment