Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
145944cb
Unverified
Commit
145944cb
authored
Feb 26, 2025
by
Harry Mellor
Committed by
GitHub
Feb 25, 2025
Browse files
Improve pipeline partitioning (#13839)
parent
094b7d94
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
8 deletions
+46
-8
tests/distributed/test_pipeline_partition.py
tests/distributed/test_pipeline_partition.py
+24
-0
vllm/distributed/utils.py
vllm/distributed/utils.py
+22
-8
No files found.
tests/distributed/test_pipeline_partition.py
View file @
145944cb
...
@@ -34,3 +34,27 @@ def test_custom_layer_partition():
...
@@ -34,3 +34,27 @@ def test_custom_layer_partition():
# Wrong number of layers
# Wrong number of layers
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
_verify
(
"5,5,5,5"
,
21
,
4
,
[(
0
,
5
),
(
5
,
10
),
(
10
,
15
),
(
15
,
20
)])
_verify
(
"5,5,5,5"
,
21
,
4
,
[(
0
,
5
),
(
5
,
10
),
(
10
,
15
),
(
15
,
20
)])
@
pytest
.
mark
.
parametrize
(
"num_hidden_layers,pp_size,pp_rank,indices"
,
[
# pp_size 2
(
2
,
2
,
0
,
(
0
,
1
)),
(
2
,
2
,
1
,
(
1
,
2
)),
(
3
,
2
,
0
,
(
0
,
2
)),
(
3
,
2
,
1
,
(
2
,
3
)),
# pp_size 3
(
3
,
3
,
0
,
(
0
,
1
)),
(
3
,
3
,
1
,
(
1
,
2
)),
(
3
,
3
,
2
,
(
2
,
3
)),
(
4
,
3
,
0
,
(
0
,
1
)),
(
4
,
3
,
1
,
(
1
,
3
)),
(
4
,
3
,
2
,
(
3
,
4
)),
(
5
,
3
,
0
,
(
0
,
2
)),
(
5
,
3
,
1
,
(
2
,
4
)),
(
5
,
3
,
2
,
(
4
,
5
)),
])
def
test_uneven_auto_partition
(
num_hidden_layers
:
int
,
pp_size
:
int
,
pp_rank
:
int
,
indices
:
tuple
[
int
,
int
]):
assert
indices
==
get_pp_indices
(
num_hidden_layers
,
pp_rank
,
pp_size
)
vllm/distributed/utils.py
View file @
145944cb
...
@@ -67,8 +67,17 @@ def split_tensor_along_last_dim(
...
@@ -67,8 +67,17 @@ def split_tensor_along_last_dim(
def
get_pp_indices
(
num_hidden_layers
:
int
,
pp_rank
:
int
,
def
get_pp_indices
(
num_hidden_layers
:
int
,
pp_rank
:
int
,
pp_size
:
int
)
->
Tuple
[
int
,
int
]:
pp_size
:
int
)
->
Tuple
[
int
,
int
]:
"""Try to evenly distribute layers across partitions.
"""Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions,
If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers.
the remaining layers are evenly distributed across all but the last
partition. The last partition is excluded because it often contains an
additional norm layer and we are attempting to balance compute.
If `pp_size > 2` and the number of remaining layers is
`0 < x <= pp_size - 2` then the remaining layers are evenly distributed
across the middle partitions. The first and last partitions are excluded
because they contain the input and output embeddings respectively and we
are attempting to reduce maximum memory consumption across partitions.
"""
"""
partition_list_str
=
envs
.
VLLM_PP_LAYER_PARTITION
partition_list_str
=
envs
.
VLLM_PP_LAYER_PARTITION
if
partition_list_str
is
not
None
:
if
partition_list_str
is
not
None
:
...
@@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
...
@@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
if
sum
(
partitions
)
!=
num_hidden_layers
:
if
sum
(
partitions
)
!=
num_hidden_layers
:
raise
ValueError
(
raise
ValueError
(
f
"
{
sum
(
partitions
)
=
}
does not match
{
num_hidden_layers
=
}
."
)
f
"
{
sum
(
partitions
)
=
}
does not match
{
num_hidden_layers
=
}
."
)
start_layer
=
sum
(
partitions
[:
pp_rank
])
end_layer
=
start_layer
+
partitions
[
pp_rank
]
else
:
else
:
layers_per_partition
=
num_hidden_layers
//
pp_size
layers_per_partition
=
num_hidden_layers
//
pp_size
start_layer
=
pp_rank
*
layers_per_partition
partitions
=
[
layers_per_partition
for
_
in
range
(
pp_size
)]
end_layer
=
start_layer
+
layers_per_partition
if
pp_rank
==
pp_size
-
1
:
if
remaining_layers
:
=
num_hidden_layers
%
pp_size
:
end_layer
=
num_hidden_layers
for
i
in
range
(
2
,
remaining_layers
+
2
):
partitions
[
-
i
]
+=
1
logger
.
info
(
"Hidden layers were unevenly partitioned: %s"
,
","
.
join
(
str
(
p
)
for
p
in
partitions
))
logger
.
info
(
"This can be manually overridden using the "
"VLLM_PP_LAYER_PARTITION environment variable"
)
start_layer
=
sum
(
partitions
[:
pp_rank
])
end_layer
=
start_layer
+
partitions
[
pp_rank
]
return
(
start_layer
,
end_layer
)
return
(
start_layer
,
end_layer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment