Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
91f50a6f
Unverified
Commit
91f50a6f
authored
Apr 23, 2024
by
youkaichao
Committed by
GitHub
Apr 23, 2024
Browse files
[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)
parent
79a268c4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
93 additions
and
71 deletions
+93
-71
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+10
-5
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+71
-51
vllm/distributed/device_communicators/pynccl_utils.py
vllm/distributed/device_communicators/pynccl_utils.py
+3
-9
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+6
-0
vllm/worker/worker.py
vllm/worker/worker.py
+3
-6
No files found.
tests/distributed/test_pynccl.py
View file @
91f50a6f
...
...
@@ -5,6 +5,7 @@ import torch
from
vllm.distributed.device_communicators.pynccl
import
(
NCCLCommunicator
,
ncclGetUniqueId
)
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.utils
import
update_environment_variables
...
...
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
for
p
in
processes
:
p
.
join
()
for
p
in
processes
:
assert
p
.
exitcode
==
0
def
update_env
(
fn
):
def
worker_fn_wrapper
(
fn
):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def
wrappe
r
(
env
):
def
wrappe
d_fn
(
env
):
update_environment_variables
(
env
)
init_distributed_environment
()
fn
()
return
wrappe
r
return
wrappe
d_fn
@
update_env
@
worker_fn_wrapper
def
worker_fn
():
comm
=
NCCLCommunicator
()
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
).
cuda
(
comm
.
rank
)
...
...
@@ -53,7 +58,7 @@ def test_pynccl():
distributed_run
(
worker_fn
,
2
)
@
update_env
@
worker_fn_wrapper
def
worker_fn_with_cudagraph
():
with
torch
.
no_grad
():
graph
=
torch
.
cuda
.
CUDAGraph
()
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
91f50a6f
...
...
@@ -20,14 +20,15 @@
# variable in the code.
import
ctypes
import
datetime
import
platform
from
typing
import
Optional
,
Union
# ===================== import region =====================
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
vllm.distributed.parallel_state
import
get_cpu_world_group
,
get_local_rank
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_nccl_library
,
nccl_integrity_check
...
...
@@ -59,6 +60,18 @@ except Exception as e:
ncclResult_t
=
ctypes
.
c_int
_c_ncclGetErrorString
=
nccl
.
ncclGetErrorString
_c_ncclGetErrorString
.
restype
=
ctypes
.
c_char_p
_c_ncclGetErrorString
.
argtypes
=
[
ncclResult_t
]
def
NCCL_CHECK
(
result
:
ncclResult_t
)
->
None
:
if
result
!=
0
:
error_str
=
_c_ncclGetErrorString
(
result
)
error_str
=
error_str
.
decode
(
"utf-8"
)
raise
RuntimeError
(
f
"NCCL error:
{
error_str
}
"
)
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion
=
nccl
.
ncclGetVersion
...
...
@@ -68,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def
ncclGetVersion
()
->
str
:
version
=
ctypes
.
c_int
()
result
=
_c_ncclGetVersion
(
ctypes
.
byref
(
version
))
assert
result
==
0
NCCL_CHECK
(
_c_ncclGetVersion
(
ctypes
.
byref
(
version
)))
# something like 21903 --> "2.19.3"
version_str
=
str
(
version
.
value
)
major
=
version_str
[
0
].
lstrip
(
"0"
)
...
...
@@ -91,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def
ncclGetUniqueId
()
->
NcclUniqueId
:
unique_id
=
NcclUniqueId
()
result
=
_c_ncclGetUniqueId
(
ctypes
.
byref
(
unique_id
))
assert
result
==
0
NCCL_CHECK
(
_c_ncclGetUniqueId
(
ctypes
.
byref
(
unique_id
)))
return
unique_id
...
...
@@ -199,66 +210,75 @@ class NCCLCommunicator:
def
__init__
(
self
,
backend
=
None
,
init_method
=
None
,
timeout
=
datetime
.
timedelta
(
seconds
=
10
),
world_size
:
int
=
-
1
,
rank
:
int
=
-
1
,
store
=
None
,
group_name
:
str
=
""
,
pg_options
=
None
,
local_rank
:
int
=
-
1
,
group
:
Optional
[
ProcessGroup
]
=
None
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]]
=
None
,
):
if
not
dist
.
is_initialized
():
backend
=
backend
or
"nccl"
assert
backend
==
'nccl'
,
(
"only use nccl backend for starting the NCCL communicator"
)
dist
.
init_process_group
(
backend
=
backend
,
init_method
=
init_method
,
timeout
=
timeout
,
world_size
=
world_size
,
rank
=
rank
,
store
=
store
,
group_name
=
group_name
,
pg_options
=
pg_options
)
self
.
rank
=
dist
.
get_rank
()
self
.
world_size
=
dist
.
get_world_size
()
if
local_rank
==
-
1
:
local_rank
=
self
.
rank
self
.
local_rank
=
local_rank
# don't use these args, as they can be -1
# use `self.rank`, `self.local_rank` and `self.world_size` instead
del
world_size
,
rank
,
local_rank
torch
.
cuda
.
set_device
(
self
.
local_rank
)
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the NCCLCommunicator to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
assert
dist
.
is_initialized
()
group
=
get_cpu_world_group
()
if
group
is
None
else
group
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"NCCLCommunicator should be attached to a non-NCCL group."
)
self
.
group
=
group
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
if
self
.
rank
==
0
:
self
.
unique_id
=
ncclGetUniqueId
()
else
:
self
.
unique_id
=
NcclUniqueId
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
)).
cuda
(
self
.
local_rank
)
dist
.
broadcast
(
tensor
,
src
=
0
)
byte_list
=
tensor
.
cpu
().
tolist
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
dist
.
broadcast
(
tensor
,
src
=
0
,
group
=
group
)
byte_list
=
tensor
.
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
self
.
comm
=
ctypes
.
c_void_p
()
result
=
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
self
.
world_size
,
self
.
unique_id
,
self
.
rank
)
assert
result
==
0
self
.
stream
=
torch
.
cuda
.
Stream
(
device
=
f
"cuda:
{
self
.
local_rank
}
"
)
if
device
is
None
:
local_rank
=
get_local_rank
()
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
elif
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
# now `device` is a `torch.device` object
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
# nccl communicator and stream will use this device
current_device
=
torch
.
cuda
.
current_device
()
try
:
torch
.
cuda
.
set_device
(
device
)
NCCL_CHECK
(
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
self
.
world_size
,
self
.
unique_id
,
self
.
rank
))
self
.
stream
=
torch
.
cuda
.
Stream
()
finally
:
torch
.
cuda
.
set_device
(
current_device
)
def
all_reduce
(
self
,
tensor
:
torch
.
Tensor
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
stream
=
None
):
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert
tensor
.
device
==
self
.
device
,
(
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
stream
=
self
.
stream
result
=
_c_ncclAllReduce
(
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
tensor
.
numel
(
),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
ncclRedOp
TypeEnum
.
from_torch
(
op
),
self
.
comm
,
ctypes
.
c_void_p
(
stream
.
cuda_stream
))
assert
result
==
0
NCCL_CHECK
(
_c_ncclAllReduce
(
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
ctypes
.
c_void_p
(
tensor
.
data_ptr
()
),
tensor
.
numel
(
),
ncclData
TypeEnum
.
from_torch
(
tensor
.
dtype
)
,
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
ctypes
.
c_void_p
(
stream
.
cuda_stream
)))
def
__del__
(
self
):
# `dist` module might have been already destroyed
...
...
vllm/distributed/device_communicators/pynccl_utils.py
View file @
91f50a6f
...
...
@@ -2,7 +2,7 @@ import contextlib
from
typing
import
Optional
import
torch
from
torch.distributed
import
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
vllm.logger
import
init_logger
...
...
@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
pass
def
init_process_group
(
world_size
:
int
,
rank
:
int
,
init_method
:
str
,
local_rank
:
int
=
-
1
)
->
None
:
def
init_process_group
(
group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
assert
not
is_initialized
()
global
comm
logger
.
info
(
f
"vLLM is using nccl==
{
ncclGetVersion
()
}
"
)
comm
=
NCCLCommunicator
(
init_method
=
init_method
,
world_size
=
world_size
,
local_rank
=
local_rank
,
rank
=
rank
)
comm
=
NCCLCommunicator
(
group
=
group
)
def
all_reduce
(
input_
:
torch
.
Tensor
,
op
=
ReduceOp
.
SUM
)
->
None
:
...
...
vllm/distributed/parallel_state.py
View file @
91f50a6f
...
...
@@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import
contextlib
import
os
from
typing
import
Optional
import
torch
...
...
@@ -73,6 +74,11 @@ def init_distributed_environment(
ranks
=
list
(
range
(
torch
.
distributed
.
get_world_size
()))
_CPU_WORLD_GROUP
=
torch
.
distributed
.
new_group
(
ranks
=
ranks
,
backend
=
"gloo"
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if
local_rank
==
-
1
and
distributed_init_method
==
"env://"
:
local_rank
=
int
(
os
.
environ
[
'LOCAL_RANK'
])
global
_LOCAL_RANK
_LOCAL_RANK
=
local_rank
...
...
vllm/worker/worker.py
View file @
91f50a6f
...
...
@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
elif
parallel_config
.
world_size
>
1
:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
pynccl_utils
.
init_process_group
(
world_size
=
parallel_config
.
world_size
,
local_rank
=
local_rank
,
rank
=
rank
,
init_method
=
distributed_init_method
,
)
# NOTE(kaichao): By default, pynccl will use information inside
# `parallel_state` for initialization.
pynccl_utils
.
init_process_group
()
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment