Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c6325f0
Unverified
Commit
3c6325f0
authored
Jul 03, 2024
by
youkaichao
Committed by
GitHub
Jul 03, 2024
Browse files
[core][distributed] custom allreduce when pp size > 1 (#6117)
parent
47f0954a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
15 deletions
+17
-15
vllm/config.py
vllm/config.py
+5
-11
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+12
-4
No files found.
vllm/config.py
View file @
3c6325f0
...
@@ -723,17 +723,11 @@ class ParallelConfig:
...
@@ -723,17 +723,11 @@ class ParallelConfig:
if
self
.
distributed_executor_backend
==
"ray"
:
if
self
.
distributed_executor_backend
==
"ray"
:
from
vllm.executor
import
ray_utils
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
ray_utils
.
assert_ray_available
()
if
not
self
.
disable_custom_all_reduce
and
self
.
world_size
>
1
:
if
is_hip
():
if
is_hip
():
self
.
disable_custom_all_reduce
=
True
self
.
disable_custom_all_reduce
=
True
logger
.
info
(
logger
.
info
(
"Disabled the custom all-reduce kernel because it is not "
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs."
)
"supported on AMD GPUs."
)
elif
self
.
pipeline_parallel_size
>
1
:
self
.
disable_custom_all_reduce
=
True
logger
.
info
(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism."
)
if
self
.
ray_workers_use_nsight
and
(
if
self
.
ray_workers_use_nsight
and
(
not
self
.
distributed_executor_backend
==
"ray"
):
not
self
.
distributed_executor_backend
==
"ray"
):
raise
ValueError
(
"Unable to use nsight profiling unless workers "
raise
ValueError
(
"Unable to use nsight profiling unless workers "
...
...
vllm/distributed/parallel_state.py
View file @
3c6325f0
...
@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
...
@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
)
)
def
init_model_parallel_group
(
group_ranks
:
List
[
List
[
int
]],
local_rank
:
int
,
def
init_model_parallel_group
(
backend
:
str
)
->
GroupCoordinator
:
group_ranks
:
List
[
List
[
int
]],
local_rank
:
int
,
backend
:
str
,
use_custom_allreduce
:
Optional
[
bool
]
=
None
)
->
GroupCoordinator
:
if
use_custom_allreduce
is
None
:
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
return
GroupCoordinator
(
return
GroupCoordinator
(
group_ranks
=
group_ranks
,
group_ranks
=
group_ranks
,
local_rank
=
local_rank
,
local_rank
=
local_rank
,
torch_distributed_backend
=
backend
,
torch_distributed_backend
=
backend
,
use_pynccl
=
True
,
use_pynccl
=
True
,
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
,
use_custom_allreduce
=
use_custom_allreduce
,
)
)
...
@@ -888,8 +893,11 @@ def initialize_model_parallel(
...
@@ -888,8 +893,11 @@ def initialize_model_parallel(
for
i
in
range
(
num_pipeline_model_parallel_groups
):
for
i
in
range
(
num_pipeline_model_parallel_groups
):
ranks
=
list
(
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
))
ranks
=
list
(
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
))
group_ranks
.
append
(
ranks
)
group_ranks
.
append
(
ranks
)
# pipeline parallel does not need custom allreduce
_PP
=
init_model_parallel_group
(
group_ranks
,
_PP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
backend
)
get_world_group
().
local_rank
,
backend
,
use_custom_allreduce
=
False
)
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment