Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d592eb4
Unverified
Commit
6d592eb4
authored
Apr 09, 2024
by
youkaichao
Committed by
GitHub
Apr 09, 2024
Browse files
[Core] separate distributed_init from worker (#3904)
parent
d036198e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
85 additions
and
58 deletions
+85
-58
vllm/model_executor/parallel_utils/parallel_state.py
vllm/model_executor/parallel_utils/parallel_state.py
+60
-3
vllm/test_utils.py
vllm/test_utils.py
+6
-7
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+7
-21
vllm/worker/worker.py
vllm/worker/worker.py
+12
-27
No files found.
vllm/model_executor/parallel_utils/parallel_state.py
View file @
6d592eb4
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
"""Tensor and pipeline parallel groups."""
import
contextlib
import
contextlib
from
typing
import
Optional
import
torch
import
torch
...
@@ -14,14 +15,59 @@ _TENSOR_MODEL_PARALLEL_GROUP = None
...
@@ -14,14 +15,59 @@ _TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP
=
None
_PIPELINE_MODEL_PARALLEL_GROUP
=
None
# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first, and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is not useful.
_DEVICE_WORLD_GROUP
=
None
# during `init_distributed_environment`, we will also initialize a
# group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP
=
None
# In summary, after calling `init_distributed_environment`, we will
# always have two groups: a device-specific one (which is the default)
# and one for CPU. All processes will be part of both groups.
# A list of global ranks for each pipeline group to ease calculation of the
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS
=
None
_PIPELINE_GLOBAL_RANKS
=
None
def
init_distributed_environment
(
world_size
:
int
,
rank
:
int
,
distributed_init_method
:
Optional
[
str
]
=
None
,
local_rank
:
int
=
-
1
,
backend
:
str
=
"nccl"
,
):
if
not
torch
.
distributed
.
is_initialized
():
assert
distributed_init_method
is
not
None
,
(
"distributed_init_method must be provided when initializing "
"distributed environment"
)
# this backend is used for WORLD
torch
.
distributed
.
init_process_group
(
backend
=
backend
,
init_method
=
distributed_init_method
,
world_size
=
world_size
,
rank
=
rank
)
global
_DEVICE_WORLD_GROUP
,
_CPU_WORLD_GROUP
_DEVICE_WORLD_GROUP
=
torch
.
distributed
.
group
.
WORLD
ranks
=
list
(
range
(
torch
.
distributed
.
get_world_size
()))
_CPU_WORLD_GROUP
=
torch
.
distributed
.
new_group
(
ranks
=
ranks
,
backend
=
"gloo"
)
def
initialize_model_parallel
(
def
initialize_model_parallel
(
tensor_model_parallel_size
:
int
=
1
,
tensor_model_parallel_size
:
int
=
1
,
pipeline_model_parallel_size
:
int
=
1
,
pipeline_model_parallel_size
:
int
=
1
,
backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
"""
"""
Initialize model parallel groups.
Initialize model parallel groups.
...
@@ -48,6 +94,8 @@ def initialize_model_parallel(
...
@@ -48,6 +94,8 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies.
# Get world size and rank. Ensure some consistencies.
assert
torch
.
distributed
.
is_initialized
()
assert
torch
.
distributed
.
is_initialized
()
world_size
:
int
=
torch
.
distributed
.
get_world_size
()
world_size
:
int
=
torch
.
distributed
.
get_world_size
()
# get the backend of _DEVICE_WORLD_GROUP
backend
=
backend
or
torch
.
distributed
.
get_backend
()
if
(
world_size
!=
if
(
world_size
!=
tensor_model_parallel_size
*
pipeline_model_parallel_size
):
tensor_model_parallel_size
*
pipeline_model_parallel_size
):
...
@@ -69,7 +117,7 @@ def initialize_model_parallel(
...
@@ -69,7 +117,7 @@ def initialize_model_parallel(
for
i
in
range
(
num_tensor_model_parallel_groups
):
for
i
in
range
(
num_tensor_model_parallel_groups
):
ranks
=
range
(
i
*
tensor_model_parallel_size
,
ranks
=
range
(
i
*
tensor_model_parallel_size
,
(
i
+
1
)
*
tensor_model_parallel_size
)
(
i
+
1
)
*
tensor_model_parallel_size
)
group
=
torch
.
distributed
.
new_group
(
ranks
)
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
if
rank
in
ranks
:
if
rank
in
ranks
:
_TENSOR_MODEL_PARALLEL_GROUP
=
group
_TENSOR_MODEL_PARALLEL_GROUP
=
group
...
@@ -80,7 +128,7 @@ def initialize_model_parallel(
...
@@ -80,7 +128,7 @@ def initialize_model_parallel(
"pipeline model parallel group is already initialized"
)
"pipeline model parallel group is already initialized"
)
for
i
in
range
(
num_pipeline_model_parallel_groups
):
for
i
in
range
(
num_pipeline_model_parallel_groups
):
ranks
=
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
)
ranks
=
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
)
group
=
torch
.
distributed
.
new_group
(
ranks
)
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
if
rank
in
ranks
:
if
rank
in
ranks
:
_PIPELINE_MODEL_PARALLEL_GROUP
=
group
_PIPELINE_MODEL_PARALLEL_GROUP
=
group
_PIPELINE_GLOBAL_RANKS
=
ranks
_PIPELINE_GLOBAL_RANKS
=
ranks
...
@@ -89,14 +137,17 @@ def initialize_model_parallel(
...
@@ -89,14 +137,17 @@ def initialize_model_parallel(
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
tensor_model_parallel_size
:
int
,
tensor_model_parallel_size
:
int
,
pipeline_model_parallel_size
:
int
,
pipeline_model_parallel_size
:
int
,
backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
"""Helper to initialize model parallel groups if they are not initialized,
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
values if the model parallel groups are initialized.
"""
"""
# get the backend of _DEVICE_WORLD_GROUP
backend
=
backend
or
torch
.
distributed
.
get_backend
()
if
not
model_parallel_is_initialized
():
if
not
model_parallel_is_initialized
():
initialize_model_parallel
(
tensor_model_parallel_size
,
initialize_model_parallel
(
tensor_model_parallel_size
,
pipeline_model_parallel_size
)
pipeline_model_parallel_size
,
backend
)
return
return
assert
(
assert
(
...
@@ -117,6 +168,12 @@ def model_parallel_is_initialized():
...
@@ -117,6 +168,12 @@ def model_parallel_is_initialized():
and
_PIPELINE_MODEL_PARALLEL_GROUP
is
not
None
)
and
_PIPELINE_MODEL_PARALLEL_GROUP
is
not
None
)
def
get_cpu_world_group
():
"""Get the CPU world group."""
assert
_CPU_WORLD_GROUP
is
not
None
,
(
"CPU world group is not initialized"
)
return
_CPU_WORLD_GROUP
def
get_tensor_model_parallel_group
():
def
get_tensor_model_parallel_group
():
"""Get the tensor model parallel group the caller rank belongs to."""
"""Get the tensor model parallel group the caller rank belongs to."""
assert
_TENSOR_MODEL_PARALLEL_GROUP
is
not
None
,
(
assert
_TENSOR_MODEL_PARALLEL_GROUP
is
not
None
,
(
...
...
vllm/test_utils.py
View file @
6d592eb4
import
ray
import
ray
from
vllm.config
import
ParallelConfig
from
vllm.model_executor.parallel_utils.parallel_state
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
from
vllm.worker.worker
import
init_distributed_environment
def
init_test_distributed_environment
(
def
init_test_distributed_environment
(
...
@@ -12,15 +12,14 @@ def init_test_distributed_environment(
...
@@ -12,15 +12,14 @@ def init_test_distributed_environment(
distributed_init_port
:
str
,
distributed_init_port
:
str
,
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
)
->
None
:
)
->
None
:
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
,
tensor_parallel_size
,
worker_use_ray
=
True
)
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
init_distributed_environment
(
init_distributed_environment
(
parallel_config
,
world_size
=
pipeline_parallel_size
*
tensor_parallel_size
,
rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
local_rank
=
local_rank
)
local_rank
=
local_rank
)
ensure_model_parallel_initialized
(
tensor_parallel_size
,
pipeline_parallel_size
)
def
multi_process_tensor_parallel
(
def
multi_process_tensor_parallel
(
...
...
vllm/worker/cpu_worker.py
View file @
6d592eb4
...
@@ -13,7 +13,7 @@ from vllm.model_executor.model_loader import get_model
...
@@ -13,7 +13,7 @@ from vllm.model_executor.model_loader import get_model
from
vllm.model_executor.parallel_utils.communication_op
import
(
from
vllm.model_executor.parallel_utils.communication_op
import
(
broadcast_tensor_dict
)
broadcast_tensor_dict
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
ensure_model_parallel_initialized
)
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.model_runner
import
ModelRunner
...
@@ -251,25 +251,11 @@ class CPUWorker:
...
@@ -251,25 +251,11 @@ class CPUWorker:
parallel_config
=
self
.
parallel_config
parallel_config
=
self
.
parallel_config
rank
=
self
.
rank
rank
=
self
.
rank
distributed_init_method
=
self
.
distributed_init_method
distributed_init_method
=
self
.
distributed_init_method
init_distributed_environment
(
if
torch
.
distributed
.
is_initialized
():
torch_world_size
=
torch
.
distributed
.
get_world_size
()
if
torch_world_size
!=
parallel_config
.
world_size
:
raise
RuntimeError
(
"torch.distributed is already initialized but the torch "
"world size does not match parallel_config.world_size "
f
"(
{
torch_world_size
}
vs.
{
parallel_config
.
world_size
}
)."
)
elif
not
distributed_init_method
:
raise
ValueError
(
"distributed_init_method must be set if torch.distributed "
"is not already initialized"
)
else
:
backend
=
"gloo"
torch
.
distributed
.
init_process_group
(
backend
=
backend
,
world_size
=
parallel_config
.
world_size
,
world_size
=
parallel_config
.
world_size
,
rank
=
rank
,
rank
=
rank
,
init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
backend
=
"gloo"
,
)
)
# A small all_reduce for warmup.
# A small all_reduce for warmup.
...
...
vllm/worker/worker.py
View file @
6d592eb4
...
@@ -15,7 +15,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
...
@@ -15,7 +15,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict
)
broadcast_tensor_dict
)
from
vllm.model_executor.parallel_utils.custom_all_reduce
import
init_custom_ar
from
vllm.model_executor.parallel_utils.custom_all_reduce
import
init_custom_ar
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
ensure_model_parallel_initialized
)
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.model_runner
import
ModelRunner
...
@@ -97,7 +97,7 @@ class Worker:
...
@@ -97,7 +97,7 @@ class Worker:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
# Initialize the distributed environment.
# Initialize the distributed environment.
init_distributed_environment
(
self
.
parallel_config
,
self
.
rank
,
init_
worker_
distributed_environment
(
self
.
parallel_config
,
self
.
rank
,
self
.
distributed_init_method
,
self
.
distributed_init_method
,
self
.
local_rank
)
self
.
local_rank
)
# Set random seed.
# Set random seed.
...
@@ -248,31 +248,15 @@ class Worker:
...
@@ -248,31 +248,15 @@ class Worker:
self
.
parallel_config
)
self
.
parallel_config
)
def
init_distributed_environment
(
def
init_
worker_
distributed_environment
(
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
rank
:
int
,
rank
:
int
,
distributed_init_method
:
Optional
[
str
]
=
None
,
distributed_init_method
:
Optional
[
str
]
=
None
,
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
)
->
None
:
)
->
None
:
"""Initialize the distributed environment."""
"""Initialize the distributed environment."""
if
torch
.
distributed
.
is_initialized
():
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
torch_world_size
=
torch
.
distributed
.
get_world_size
()
distributed_init_method
,
local_rank
)
if
torch_world_size
!=
parallel_config
.
world_size
:
raise
RuntimeError
(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f
"(
{
torch_world_size
}
vs.
{
parallel_config
.
world_size
}
)."
)
elif
not
distributed_init_method
:
raise
ValueError
(
"distributed_init_method must be set if torch.distributed "
"is not already initialized"
)
else
:
torch
.
distributed
.
init_process_group
(
backend
=
"nccl"
,
world_size
=
parallel_config
.
world_size
,
rank
=
rank
,
init_method
=
distributed_init_method
,
)
if
pynccl_utils
.
is_initialized
():
if
pynccl_utils
.
is_initialized
():
pynccl_world_size
=
pynccl_utils
.
get_world_size
()
pynccl_world_size
=
pynccl_utils
.
get_world_size
()
...
@@ -291,10 +275,6 @@ def init_distributed_environment(
...
@@ -291,10 +275,6 @@ def init_distributed_environment(
init_method
=
distributed_init_method
,
init_method
=
distributed_init_method
,
)
)
# A small all_reduce for warmup.
torch
.
distributed
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
if
pynccl_utils
.
is_initialized
():
pynccl_utils
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
pipeline_parallel_size
)
...
@@ -302,6 +282,11 @@ def init_distributed_environment(
...
@@ -302,6 +282,11 @@ def init_distributed_environment(
if
not
parallel_config
.
disable_custom_all_reduce
:
if
not
parallel_config
.
disable_custom_all_reduce
:
init_custom_ar
()
init_custom_ar
()
# A small all_reduce for warmup.
torch
.
distributed
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
if
pynccl_utils
.
is_initialized
():
pynccl_utils
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
# Check if the GPU supports the dtype.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment