Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
344a5d0c
Unverified
Commit
344a5d0c
authored
May 02, 2024
by
youkaichao
Committed by
GitHub
May 02, 2024
Browse files
[Core][Distributed] enable allreduce for multiple tp groups (#4566)
parent
0f8a9140
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
22 deletions
+71
-22
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+39
-4
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+0
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+25
-11
vllm/worker/worker.py
vllm/worker/worker.py
+7
-6
No files found.
tests/distributed/test_pynccl.py
View file @
344a5d0c
...
...
@@ -3,9 +3,13 @@ import multiprocessing
import
pytest
import
torch
import
vllm.distributed.device_communicators.pynccl_utils
as
pynccl_utils
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.device_communicators.pynccl
import
(
NCCLCommunicator
,
ncclGetUniqueId
)
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
get_tensor_model_parallel_cpu_group
,
init_distributed_environment
,
with_pynccl_for_all_reduce
)
from
vllm.utils
import
update_environment_variables
...
...
@@ -67,7 +71,7 @@ def multiple_tp_worker_fn():
]
group
=
groups
[
0
]
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]
else
groups
[
1
]
comm
=
NCCLCommunicator
(
group
=
group
,
device
=
device
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
).
cuda
(
comm
.
rank
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
# two groups can communicate independently
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
comm
.
all_reduce
(
tensor
)
...
...
@@ -81,9 +85,40 @@ def multiple_tp_worker_fn():
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
4
,
reason
=
"Need at least
2
GPUs to run the test."
)
reason
=
"Need at least
4
GPUs to run the test."
)
def
test_pynccl_multiple_tp
():
distributed_run
(
worker_fn
,
4
)
# this tests pynccl for multiple tp groups, in a standalone way
# i.e. call `comm.all_reduce` directly
distributed_run
(
multiple_tp_worker_fn
,
4
)
@worker_fn_wrapper
def multiple_tp_with_vllm_worker_fn():
    """Worker body: run pynccl all-reduce through vLLM's tensor-parallel op.

    Every rank binds to the CUDA device whose index equals its global rank,
    calls ``ensure_model_parallel_initialized(2, 2)`` and initializes pynccl
    on the tensor-parallel CPU group.  Ranks 0 and 1 then all-reduce a tensor
    of ones twice and expect a mean of 4; every other rank all-reduces once
    and expects a mean of 2.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(rank)

    ensure_model_parallel_initialized(2, 2)
    tp_cpu_group = get_tensor_model_parallel_cpu_group()
    pynccl_utils.init_process_group(group=tp_cpu_group)

    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)

    # two tp groups can communicate independently: ranks 0 and 1 issue two
    # all-reduces while the remaining ranks issue only one.
    if rank in [0, 1]:
        num_rounds, expected_mean = 2, 4
    else:
        num_rounds, expected_mean = 1, 2

    with with_pynccl_for_all_reduce():
        for _ in range(num_rounds):
            tensor = tensor_model_parallel_all_reduce(tensor)
        result = tensor.mean().cpu().item()
        assert result == expected_mean
@pytest.mark.skipif(
    torch.cuda.device_count() < 4,
    reason="Need at least 4 GPUs to run the test.",
)
def test_pynccl_multiple_tp_with_vllm():
    """Test pynccl for multiple tp groups, together with vllm.

    The workers call `tensor_model_parallel_all_reduce` rather than using a
    communicator's `all_reduce` directly.
    """
    world_size = 4
    distributed_run(multiple_tp_with_vllm_worker_fn, world_size)
@
worker_fn_wrapper
...
...
vllm/distributed/communication_op.py
View file @
344a5d0c
...
...
@@ -34,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
if
out
is
not
None
:
return
out
if
is_pynccl_enabled_for_all_reduce
():
# TODO: support multiple parallel groups.
pynccl_utils
.
all_reduce
(
input_
)
else
:
torch
.
distributed
.
all_reduce
(
input_
,
...
...
vllm/distributed/parallel_state.py
View file @
344a5d0c
...
...
@@ -14,7 +14,8 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP
=
None
_TP_DEVICE_GROUP
=
None
_TP_CPU_GROUP
=
None
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP
=
None
...
...
@@ -132,15 +133,17 @@ def initialize_model_parallel(
rank
=
torch
.
distributed
.
get_rank
()
# Build the tensor model-parallel groups.
global
_T
ENSOR_MODEL_PARALLEL
_GROUP
assert
_T
ENSOR_MODEL_PARALLEL
_GROUP
is
None
,
(
global
_T
P_DEVICE_GROUP
,
_TP_CPU
_GROUP
assert
_T
P_DEVICE
_GROUP
is
None
,
(
"tensor model parallel group is already initialized"
)
for
i
in
range
(
num_tensor_model_parallel_groups
):
ranks
=
range
(
i
*
tensor_model_parallel_size
,
(
i
+
1
)
*
tensor_model_parallel_size
)
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
)
if
rank
in
ranks
:
_TENSOR_MODEL_PARALLEL_GROUP
=
group
_TP_DEVICE_GROUP
=
group
_TP_CPU_GROUP
=
cpu_group
# Build the pipeline model-parallel groups.
global
_PIPELINE_MODEL_PARALLEL_GROUP
...
...
@@ -185,7 +188,7 @@ def ensure_model_parallel_initialized(
def
model_parallel_is_initialized
():
"""Check if tensor and pipeline parallel groups are initialized."""
return
(
_T
ENSOR_MODEL_PARALLEL
_GROUP
is
not
None
return
(
_T
P_DEVICE
_GROUP
is
not
None
and
_PIPELINE_MODEL_PARALLEL_GROUP
is
not
None
)
...
...
@@ -197,9 +200,16 @@ def get_cpu_world_group():
def
get_tensor_model_parallel_group
():
"""Get the tensor model parallel group the caller rank belongs to."""
assert
_T
ENSOR_MODEL_PARALLEL
_GROUP
is
not
None
,
(
assert
_T
P_DEVICE
_GROUP
is
not
None
,
(
"tensor model parallel group is not initialized"
)
return
_TENSOR_MODEL_PARALLEL_GROUP
return
_TP_DEVICE_GROUP
def get_tensor_model_parallel_cpu_group():
    """Return the tensor model parallel CPU group of the calling rank.

    Raises an ``AssertionError`` when the module-level ``_TP_CPU_GROUP``
    has not been set yet.
    """
    cpu_group = _TP_CPU_GROUP
    assert cpu_group is not None, (
        "tensor model parallel cpu group is not initialized")
    return cpu_group
def
get_pipeline_model_parallel_group
():
...
...
@@ -277,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
def
destroy_model_parallel
():
"""Set the groups to none and destroy them."""
global
_TENSOR_MODEL_PARALLEL_GROUP
if
_TENSOR_MODEL_PARALLEL_GROUP
:
torch
.
distributed
.
destroy_process_group
(
_TENSOR_MODEL_PARALLEL_GROUP
)
_TENSOR_MODEL_PARALLEL_GROUP
=
None
global
_TP_DEVICE_GROUP
if
_TP_DEVICE_GROUP
:
torch
.
distributed
.
destroy_process_group
(
_TP_DEVICE_GROUP
)
_TP_DEVICE_GROUP
=
None
global
_TP_CPU_GROUP
if
_TP_CPU_GROUP
:
torch
.
distributed
.
destroy_process_group
(
_TP_CPU_GROUP
)
_TP_CPU_GROUP
=
None
global
_PIPELINE_MODEL_PARALLEL_GROUP
if
_PIPELINE_MODEL_PARALLEL_GROUP
:
torch
.
distributed
.
destroy_process_group
(
_PIPELINE_MODEL_PARALLEL_GROUP
)
...
...
vllm/worker/worker.py
View file @
344a5d0c
...
...
@@ -11,6 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
VisionLanguageConfig
)
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
get_tensor_model_parallel_cpu_group
,
init_distributed_environment
)
from
vllm.distributed.device_communicators
import
pynccl_utils
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
...
...
@@ -288,6 +289,9 @@ def init_worker_distributed_environment(
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
distributed_init_method
,
local_rank
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
if
pynccl_utils
.
is_initialized
():
pynccl_world_size
=
pynccl_utils
.
get_world_size
()
if
pynccl_world_size
!=
parallel_config
.
world_size
:
...
...
@@ -298,12 +302,9 @@ def init_worker_distributed_environment(
elif
parallel_config
.
world_size
>
1
:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
# NOTE(kaichao): By default, pynccl will use information inside
# `parallel_state` for initialization.
pynccl_utils
.
init_process_group
()
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
# NOTE(kaichao): By default, pynccl is initialized for tp group.
pynccl_utils
.
init_process_group
(
group
=
get_tensor_model_parallel_cpu_group
())
# Initialize a custom fast all-reduce implementation.
if
not
parallel_config
.
disable_custom_all_reduce
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment