Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
515386ef
Unverified
Commit
515386ef
authored
Mar 29, 2024
by
Roy
Committed by
GitHub
Mar 28, 2024
Browse files
[Core] Support multi-node inference(eager and cuda graph) (#3686)
parent
a4075cba
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
25 additions
and
22 deletions
+25
-22
tests/distributed/test_comm_ops.py
tests/distributed/test_comm_ops.py
+3
-3
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+2
-2
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+0
-2
vllm/model_executor/parallel_utils/pynccl.py
vllm/model_executor/parallel_utils/pynccl.py
+8
-10
vllm/model_executor/parallel_utils/pynccl_utils.py
vllm/model_executor/parallel_utils/pynccl_utils.py
+3
-1
vllm/test_utils.py
vllm/test_utils.py
+5
-1
vllm/worker/worker.py
vllm/worker/worker.py
+4
-3
No files found.
tests/distributed/test_comm_ops.py
View file @
515386ef
...
...
@@ -24,7 +24,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
tensor_parallel_size
,
rank
,
init_test_distributed_environment
(
1
,
tensor_parallel_size
,
rank
,
rank
,
distributed_init_port
)
num_elements
=
8
all_tensors
=
[
...
...
@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
tensor_parallel_size
,
rank
,
init_test_distributed_environment
(
1
,
tensor_parallel_size
,
rank
,
rank
,
distributed_init_port
)
num_dimensions
=
3
tensor_size
=
list
(
range
(
2
,
num_dimensions
+
2
))
...
...
@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
tensor_parallel_size
,
rank
,
init_test_distributed_environment
(
1
,
tensor_parallel_size
,
rank
,
rank
,
distributed_init_port
)
test_dict
=
{
"a"
:
torch
.
arange
(
8
,
dtype
=
torch
.
float32
,
device
=
"cuda"
),
...
...
tests/distributed/test_custom_all_reduce.py
View file @
515386ef
...
...
@@ -23,7 +23,7 @@ def graph_allreduce(world_size, rank, distributed_init_port):
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
world_size
,
rank
,
init_test_distributed_environment
(
1
,
world_size
,
rank
,
rank
,
distributed_init_port
)
custom_ar
.
init_custom_ar
()
...
...
@@ -58,7 +58,7 @@ def eager_allreduce(world_size, rank, distributed_init_port):
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
world_size
,
rank
,
init_test_distributed_environment
(
1
,
world_size
,
rank
,
rank
,
distributed_init_port
)
sz
=
1024
...
...
vllm/executor/ray_gpu_executor.py
View file @
515386ef
...
...
@@ -188,8 +188,6 @@ class RayGPUExecutor(ExecutorBase):
is_driver_worker
=
True
,
)
# FIXME(woosuk): We are not properly initializing pynccl when
# we have multiple nodes.
self
.
_run_workers
(
"init_device"
)
self
.
_run_workers
(
"load_model"
,
...
...
vllm/model_executor/parallel_utils/pynccl.py
View file @
515386ef
...
...
@@ -202,6 +202,7 @@ class NCCLCommunicator:
init_method
=
None
,
timeout
=
datetime
.
timedelta
(
seconds
=
10
),
world_size
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
rank
:
int
=
-
1
,
store
=
None
,
group_name
:
str
=
""
,
...
...
@@ -219,25 +220,22 @@ class NCCLCommunicator:
store
=
store
,
group_name
=
group_name
,
pg_options
=
pg_options
)
self
.
world_size
=
dist
.
get_world_size
()
self
.
rank
=
dist
.
get_rank
()
torch
.
cuda
.
set_device
(
self
.
rank
)
if
self
.
rank
==
0
:
torch
.
cuda
.
set_device
(
local_rank
)
if
rank
==
0
:
self
.
unique_id
=
ncclGetUniqueId
()
else
:
self
.
unique_id
=
NcclUniqueId
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
)).
cuda
(
self
.
rank
)
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
)).
cuda
(
local_
rank
)
dist
.
broadcast
(
tensor
,
src
=
0
)
byte_list
=
tensor
.
cpu
().
tolist
()
self
.
unique_id
=
NcclUniqueId
()
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
self
.
comm
=
ctypes
.
c_void_p
()
result
=
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
self
.
world_size
,
self
.
unique_id
,
self
.
rank
)
result
=
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
world_size
,
self
.
unique_id
,
rank
)
assert
result
==
0
self
.
stream
=
torch
.
cuda
.
Stream
(
device
=
f
"cuda:
{
self
.
rank
}
"
)
self
.
stream
=
torch
.
cuda
.
Stream
(
device
=
f
"cuda:
{
local_
rank
}
"
)
def
all_reduce
(
self
,
tensor
:
torch
.
Tensor
,
...
...
vllm/model_executor/parallel_utils/pynccl_utils.py
View file @
515386ef
...
...
@@ -36,11 +36,13 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
pass
def
init_process_group
(
world_size
:
int
,
rank
:
int
,
init_method
:
str
)
->
None
:
def
init_process_group
(
world_size
:
int
,
local_rank
:
int
,
rank
:
int
,
init_method
:
str
)
->
None
:
assert
not
is_initialized
()
global
comm
comm
=
NCCLCommunicator
(
init_method
=
init_method
,
world_size
=
world_size
,
local_rank
=
local_rank
,
rank
=
rank
)
...
...
vllm/test_utils.py
View file @
515386ef
...
...
@@ -8,6 +8,7 @@ from vllm.worker.worker import init_distributed_environment
def
init_test_distributed_environment
(
pipeline_parallel_size
:
int
,
tensor_parallel_size
:
int
,
local_rank
:
int
,
rank
:
int
,
distributed_init_port
:
str
,
)
->
None
:
...
...
@@ -16,7 +17,10 @@ def init_test_distributed_environment(
worker_use_ray
=
True
)
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
init_distributed_environment
(
parallel_config
,
rank
,
distributed_init_method
=
distributed_init_method
)
parallel_config
,
local_rank
,
rank
,
distributed_init_method
=
distributed_init_method
)
def
multi_process_tensor_parallel
(
...
...
vllm/worker/worker.py
View file @
515386ef
...
...
@@ -97,8 +97,8 @@ class Worker:
raise
RuntimeError
(
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
# Initialize the distributed environment.
init_distributed_environment
(
self
.
parallel_config
,
self
.
rank
,
self
.
distributed_init_method
)
init_distributed_environment
(
self
.
parallel_config
,
self
.
local_
rank
,
self
.
rank
,
self
.
distributed_init_method
)
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
...
...
@@ -249,6 +249,7 @@ class Worker:
def
init_distributed_environment
(
parallel_config
:
ParallelConfig
,
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
Optional
[
str
]
=
None
,
)
->
None
:
...
...
@@ -282,9 +283,9 @@ def init_distributed_environment(
elif
parallel_config
.
world_size
>
1
:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
# TODO(woosuk): Support multi-node connection.
pynccl_utils
.
init_process_group
(
world_size
=
parallel_config
.
world_size
,
local_rank
=
local_rank
,
rank
=
rank
,
init_method
=
distributed_init_method
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment