Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3de6e6a3
Unverified
Commit
3de6e6a3
authored
Jul 03, 2024
by
youkaichao
Committed by
GitHub
Jul 03, 2024
Browse files
[core][distributed] support n layers % pp size != 0 (#6115)
parent
966fe721
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
19 additions
and
10 deletions
+19
-10
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
vllm/config.py
vllm/config.py
+6
-9
vllm/distributed/utils.py
vllm/distributed/utils.py
+8
-1
vllm/worker/openvino_worker.py
vllm/worker/openvino_worker.py
+1
-0
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+1
-0
vllm/worker/worker.py
vllm/worker/worker.py
+1
-0
vllm/worker/xpu_worker.py
vllm/worker/xpu_worker.py
+1
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
3de6e6a3
...
@@ -80,6 +80,7 @@ steps:
...
@@ -80,6 +80,7 @@ steps:
commands
:
commands
:
-
TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-
TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-
TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-
TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-
TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-
PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-
PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
-
PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
-
PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
...
...
vllm/config.py
View file @
3de6e6a3
...
@@ -265,8 +265,6 @@ class ModelConfig:
...
@@ -265,8 +265,6 @@ class ModelConfig:
" must be divisible by tensor parallel size "
" must be divisible by tensor parallel size "
f
"(
{
tensor_parallel_size
}
)."
)
f
"(
{
tensor_parallel_size
}
)."
)
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_hidden_layers"
,
0
)
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
if
not
all
(
arch
in
_PP_SUPPORTED_MODELS
if
not
all
(
arch
in
_PP_SUPPORTED_MODELS
...
@@ -275,12 +273,6 @@ class ModelConfig:
...
@@ -275,12 +273,6 @@ class ModelConfig:
"Pipeline parallelism is only supported for the following "
"Pipeline parallelism is only supported for the following "
f
" architectures:
{
_PP_SUPPORTED_MODELS
}
."
)
f
" architectures:
{
_PP_SUPPORTED_MODELS
}
."
)
if
total_num_hidden_layers
%
pipeline_parallel_size
!=
0
:
raise
ValueError
(
f
"Total number of hidden layers (
{
total_num_hidden_layers
}
) "
"must be divisible by pipeline parallel size "
f
"(
{
pipeline_parallel_size
}
)."
)
if
self
.
quantization
==
"bitsandbytes"
and
(
if
self
.
quantization
==
"bitsandbytes"
and
(
parallel_config
.
tensor_parallel_size
>
1
parallel_config
.
tensor_parallel_size
>
1
or
parallel_config
.
pipeline_parallel_size
>
1
):
or
parallel_config
.
pipeline_parallel_size
>
1
):
...
@@ -385,9 +377,13 @@ class ModelConfig:
...
@@ -385,9 +377,13 @@ class ModelConfig:
return
num_heads
//
parallel_config
.
tensor_parallel_size
return
num_heads
//
parallel_config
.
tensor_parallel_size
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
from
vllm.distributed.utils
import
get_pp_indices
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_hidden_layers"
,
0
)
"num_hidden_layers"
,
0
)
return
total_num_hidden_layers
//
parallel_config
.
pipeline_parallel_size
pp_rank
=
parallel_config
.
rank
//
parallel_config
.
tensor_parallel_size
pp_size
=
parallel_config
.
pipeline_parallel_size
start
,
end
=
get_pp_indices
(
total_num_hidden_layers
,
pp_rank
,
pp_size
)
return
end
-
start
def
contains_seqlen_agnostic_layers
(
def
contains_seqlen_agnostic_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
bool
:
self
,
parallel_config
:
"ParallelConfig"
)
->
bool
:
...
@@ -709,6 +705,7 @@ class ParallelConfig:
...
@@ -709,6 +705,7 @@ class ParallelConfig:
{
"CUDA_VISIBLE_DEVICES"
:
envs
.
CUDA_VISIBLE_DEVICES
})
{
"CUDA_VISIBLE_DEVICES"
:
envs
.
CUDA_VISIBLE_DEVICES
})
self
.
_verify_args
()
self
.
_verify_args
()
self
.
rank
=
0
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
if
(
self
.
pipeline_parallel_size
>
1
if
(
self
.
pipeline_parallel_size
>
1
...
...
vllm/distributed/utils.py
View file @
3de6e6a3
...
@@ -50,8 +50,15 @@ def split_tensor_along_last_dim(
...
@@ -50,8 +50,15 @@ def split_tensor_along_last_dim(
def
get_pp_indices
(
num_hidden_layers
:
int
,
pp_rank
:
int
,
def
get_pp_indices
(
num_hidden_layers
:
int
,
pp_rank
:
int
,
pp_size
:
int
)
->
Tuple
[
int
,
int
]:
pp_size
:
int
)
->
Tuple
[
int
,
int
]:
layers_per_partition
=
divide
(
num_hidden_layers
,
pp_size
)
"""Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers.
"""
layers_per_partition
=
num_hidden_layers
//
pp_size
start_layer
=
pp_rank
*
layers_per_partition
start_layer
=
pp_rank
*
layers_per_partition
end_layer
=
start_layer
+
layers_per_partition
end_layer
=
start_layer
+
layers_per_partition
if
pp_rank
==
pp_size
-
1
:
end_layer
=
num_hidden_layers
return
(
start_layer
,
end_layer
)
return
(
start_layer
,
end_layer
)
vllm/worker/openvino_worker.py
View file @
3de6e6a3
...
@@ -154,6 +154,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
...
@@ -154,6 +154,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
.
rank
=
rank
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
...
vllm/worker/tpu_worker.py
View file @
3de6e6a3
...
@@ -39,6 +39,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -39,6 +39,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
.
rank
=
rank
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
...
vllm/worker/worker.py
View file @
3de6e6a3
...
@@ -50,6 +50,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -50,6 +50,7 @@ class Worker(LocalOrDistributedWorkerBase):
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
.
rank
=
rank
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
...
vllm/worker/xpu_worker.py
View file @
3de6e6a3
...
@@ -54,6 +54,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
...
@@ -54,6 +54,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
.
rank
=
rank
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment