Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
91f50a6f
Unverified
Commit
91f50a6f
authored
Apr 23, 2024
by
youkaichao
Committed by
GitHub
Apr 23, 2024
Browse files
[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)
parent
79a268c4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
93 additions
and
71 deletions
+93
-71
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+10
-5
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+71
-51
vllm/distributed/device_communicators/pynccl_utils.py
vllm/distributed/device_communicators/pynccl_utils.py
+3
-9
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+6
-0
vllm/worker/worker.py
vllm/worker/worker.py
+3
-6
No files found.
tests/distributed/test_pynccl.py
View file @
91f50a6f
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
from
vllm.distributed.device_communicators.pynccl
import
(
NCCLCommunicator
,
from
vllm.distributed.device_communicators.pynccl
import
(
NCCLCommunicator
,
ncclGetUniqueId
)
ncclGetUniqueId
)
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.utils
import
update_environment_variables
from
vllm.utils
import
update_environment_variables
...
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
...
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
for
p
in
processes
:
for
p
in
processes
:
p
.
join
()
p
.
join
()
for
p
in
processes
:
assert
p
.
exitcode
==
0
def
update_env
(
fn
):
def
worker_fn_wrapper
(
fn
):
# `multiprocessing.Process` cannot accept environment variables directly
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
# and update the environment variables in the function
def
wrappe
r
(
env
):
def
wrappe
d_fn
(
env
):
update_environment_variables
(
env
)
update_environment_variables
(
env
)
init_distributed_environment
()
fn
()
fn
()
return
wrappe
r
return
wrappe
d_fn
@
update_env
@
worker_fn_wrapper
def
worker_fn
():
def
worker_fn
():
comm
=
NCCLCommunicator
()
comm
=
NCCLCommunicator
()
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
).
cuda
(
comm
.
rank
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
).
cuda
(
comm
.
rank
)
...
@@ -53,7 +58,7 @@ def test_pynccl():
...
@@ -53,7 +58,7 @@ def test_pynccl():
distributed_run
(
worker_fn
,
2
)
distributed_run
(
worker_fn
,
2
)
@
update_env
@
worker_fn_wrapper
def
worker_fn_with_cudagraph
():
def
worker_fn_with_cudagraph
():
with
torch
.
no_grad
():
with
torch
.
no_grad
():
graph
=
torch
.
cuda
.
CUDAGraph
()
graph
=
torch
.
cuda
.
CUDAGraph
()
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
91f50a6f
...
@@ -20,14 +20,15 @@
...
@@ -20,14 +20,15 @@
# variable in the code.
# variable in the code.
import
ctypes
import
ctypes
import
datetime
import
platform
import
platform
from
typing
import
Optional
,
Union
# ===================== import region =====================
# ===================== import region =====================
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
vllm.distributed.parallel_state
import
get_cpu_world_group
,
get_local_rank
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_nccl_library
,
nccl_integrity_check
from
vllm.utils
import
find_nccl_library
,
nccl_integrity_check
...
@@ -59,6 +60,18 @@ except Exception as e:
...
@@ -59,6 +60,18 @@ except Exception as e:
ncclResult_t
=
ctypes
.
c_int
ncclResult_t
=
ctypes
.
c_int
_c_ncclGetErrorString
=
nccl
.
ncclGetErrorString
_c_ncclGetErrorString
.
restype
=
ctypes
.
c_char_p
_c_ncclGetErrorString
.
argtypes
=
[
ncclResult_t
]
def
NCCL_CHECK
(
result
:
ncclResult_t
)
->
None
:
if
result
!=
0
:
error_str
=
_c_ncclGetErrorString
(
result
)
error_str
=
error_str
.
decode
(
"utf-8"
)
raise
RuntimeError
(
f
"NCCL error:
{
error_str
}
"
)
# equivalent to c declaration:
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion
=
nccl
.
ncclGetVersion
_c_ncclGetVersion
=
nccl
.
ncclGetVersion
...
@@ -68,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
...
@@ -68,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def
ncclGetVersion
()
->
str
:
def
ncclGetVersion
()
->
str
:
version
=
ctypes
.
c_int
()
version
=
ctypes
.
c_int
()
result
=
_c_ncclGetVersion
(
ctypes
.
byref
(
version
))
NCCL_CHECK
(
_c_ncclGetVersion
(
ctypes
.
byref
(
version
)))
assert
result
==
0
# something like 21903 --> "2.19.3"
# something like 21903 --> "2.19.3"
version_str
=
str
(
version
.
value
)
version_str
=
str
(
version
.
value
)
major
=
version_str
[
0
].
lstrip
(
"0"
)
major
=
version_str
[
0
].
lstrip
(
"0"
)
...
@@ -91,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
...
@@ -91,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def
ncclGetUniqueId
()
->
NcclUniqueId
:
def
ncclGetUniqueId
()
->
NcclUniqueId
:
unique_id
=
NcclUniqueId
()
unique_id
=
NcclUniqueId
()
result
=
_c_ncclGetUniqueId
(
ctypes
.
byref
(
unique_id
))
NCCL_CHECK
(
_c_ncclGetUniqueId
(
ctypes
.
byref
(
unique_id
)))
assert
result
==
0
return
unique_id
return
unique_id
...
@@ -199,66 +210,75 @@ class NCCLCommunicator:
...
@@ -199,66 +210,75 @@ class NCCLCommunicator:
def
__init__
(
def
__init__
(
self
,
self
,
backend
=
None
,
group
:
Optional
[
ProcessGroup
]
=
None
,
init_method
=
None
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]]
=
None
,
timeout
=
datetime
.
timedelta
(
seconds
=
10
),
world_size
:
int
=
-
1
,
rank
:
int
=
-
1
,
store
=
None
,
group_name
:
str
=
""
,
pg_options
=
None
,
local_rank
:
int
=
-
1
,
):
):
if
not
dist
.
is_initialized
():
"""
backend
=
backend
or
"nccl"
Args:
assert
backend
==
'nccl'
,
(
group: the process group to work on. If None, it will use the
"only use nccl backend for starting the NCCL communicator"
)
default process group.
dist
.
init_process_group
(
backend
=
backend
,
device: the device to bind the NCCLCommunicator to. If None,
init_method
=
init_method
,
it will be bind to f"cuda:{local_rank}".
timeout
=
timeout
,
It is the caller's responsibility to make sure each communicator
world_size
=
world_size
,
is bind to a unique device.
rank
=
rank
,
"""
store
=
store
,
assert
dist
.
is_initialized
()
group_name
=
group_name
,
group
=
get_cpu_world_group
()
if
group
is
None
else
group
pg_options
=
pg_options
)
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
self
.
rank
=
dist
.
get_rank
()
"NCCLCommunicator should be attached to a non-NCCL group."
)
self
.
world_size
=
dist
.
get_world_size
()
self
.
group
=
group
if
local_rank
==
-
1
:
self
.
rank
=
dist
.
get_rank
(
group
)
local_rank
=
self
.
rank
self
.
world_size
=
dist
.
get_world_size
(
group
)
self
.
local_rank
=
local_rank
# don't use these args, as they can be -1
# use `self.rank`, `self.local_rank` and `self.world_size` instead
del
world_size
,
rank
,
local_rank
torch
.
cuda
.
set_device
(
self
.
local_rank
)
if
self
.
rank
==
0
:
if
self
.
rank
==
0
:
self
.
unique_id
=
ncclGetUniqueId
()
self
.
unique_id
=
ncclGetUniqueId
()
else
:
else
:
self
.
unique_id
=
NcclUniqueId
()
self
.
unique_id
=
NcclUniqueId
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
)).
cuda
(
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
self
.
local_rank
)
dist
.
broadcast
(
tensor
,
src
=
0
,
group
=
group
)
dist
.
broadcast
(
tensor
,
src
=
0
)
byte_list
=
tensor
.
tolist
()
byte_list
=
tensor
.
cpu
().
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
self
.
unique_id
.
internal
[
i
]
=
byte
self
.
comm
=
ctypes
.
c_void_p
()
self
.
comm
=
ctypes
.
c_void_p
()
result
=
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
self
.
world_size
,
if
device
is
None
:
self
.
unique_id
,
self
.
rank
)
local_rank
=
get_local_rank
()
assert
result
==
0
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
self
.
stream
=
torch
.
cuda
.
Stream
(
device
=
f
"cuda:
{
self
.
local_rank
}
"
)
elif
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
# now `device` is a `torch.device` object
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
# nccl communicator and stream will use this device
current_device
=
torch
.
cuda
.
current_device
()
try
:
torch
.
cuda
.
set_device
(
device
)
NCCL_CHECK
(
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
self
.
world_size
,
self
.
unique_id
,
self
.
rank
))
self
.
stream
=
torch
.
cuda
.
Stream
()
finally
:
torch
.
cuda
.
set_device
(
current_device
)
def
all_reduce
(
self
,
def
all_reduce
(
self
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
stream
=
None
):
stream
=
None
):
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert
tensor
.
device
==
self
.
device
,
(
f
"this nccl communicator is created to work on
{
self
.
device
}
, "
f
"but the input tensor is on
{
tensor
.
device
}
"
)
if
stream
is
None
:
if
stream
is
None
:
stream
=
self
.
stream
stream
=
self
.
stream
result
=
_c_ncclAllReduce
(
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
NCCL_CHECK
(
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
_c_ncclAllReduce
(
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
tensor
.
numel
(
),
ctypes
.
c_void_p
(
tensor
.
data_ptr
()
),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
tensor
.
numel
(
),
ncclRedOp
TypeEnum
.
from_torch
(
op
),
self
.
comm
,
ncclData
TypeEnum
.
from_torch
(
tensor
.
dtype
)
,
ctypes
.
c_void_p
(
stream
.
cuda_stream
))
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
assert
result
==
0
ctypes
.
c_void_p
(
stream
.
cuda_stream
)))
def
__del__
(
self
):
def
__del__
(
self
):
# `dist` module might have been already destroyed
# `dist` module might have been already destroyed
...
...
vllm/distributed/device_communicators/pynccl_utils.py
View file @
91f50a6f
...
@@ -2,7 +2,7 @@ import contextlib
...
@@ -2,7 +2,7 @@ import contextlib
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
from
torch.distributed
import
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
...
@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
pass
pass
def
init_process_group
(
world_size
:
int
,
def
init_process_group
(
group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
rank
:
int
,
init_method
:
str
,
local_rank
:
int
=
-
1
)
->
None
:
assert
not
is_initialized
()
assert
not
is_initialized
()
global
comm
global
comm
logger
.
info
(
f
"vLLM is using nccl==
{
ncclGetVersion
()
}
"
)
logger
.
info
(
f
"vLLM is using nccl==
{
ncclGetVersion
()
}
"
)
comm
=
NCCLCommunicator
(
init_method
=
init_method
,
comm
=
NCCLCommunicator
(
group
=
group
)
world_size
=
world_size
,
local_rank
=
local_rank
,
rank
=
rank
)
def
all_reduce
(
input_
:
torch
.
Tensor
,
op
=
ReduceOp
.
SUM
)
->
None
:
def
all_reduce
(
input_
:
torch
.
Tensor
,
op
=
ReduceOp
.
SUM
)
->
None
:
...
...
vllm/distributed/parallel_state.py
View file @
91f50a6f
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
"""Tensor and pipeline parallel groups."""
import
contextlib
import
contextlib
import
os
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
...
@@ -73,6 +74,11 @@ def init_distributed_environment(
...
@@ -73,6 +74,11 @@ def init_distributed_environment(
ranks
=
list
(
range
(
torch
.
distributed
.
get_world_size
()))
ranks
=
list
(
range
(
torch
.
distributed
.
get_world_size
()))
_CPU_WORLD_GROUP
=
torch
.
distributed
.
new_group
(
ranks
=
ranks
,
_CPU_WORLD_GROUP
=
torch
.
distributed
.
new_group
(
ranks
=
ranks
,
backend
=
"gloo"
)
backend
=
"gloo"
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if
local_rank
==
-
1
and
distributed_init_method
==
"env://"
:
local_rank
=
int
(
os
.
environ
[
'LOCAL_RANK'
])
global
_LOCAL_RANK
global
_LOCAL_RANK
_LOCAL_RANK
=
local_rank
_LOCAL_RANK
=
local_rank
...
...
vllm/worker/worker.py
View file @
91f50a6f
...
@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
...
@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
elif
parallel_config
.
world_size
>
1
:
elif
parallel_config
.
world_size
>
1
:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
# is 1.
pynccl_utils
.
init_process_group
(
# NOTE(kaichao): By default, pynccl will use information inside
world_size
=
parallel_config
.
world_size
,
# `parallel_state` for initialization.
local_rank
=
local_rank
,
pynccl_utils
.
init_process_group
()
rank
=
rank
,
init_method
=
distributed_init_method
,
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
pipeline_parallel_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment