Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b5af8c22
Unverified
Commit
b5af8c22
authored
Jul 17, 2024
by
Cody Yu
Committed by
GitHub
Jul 17, 2024
Browse files
[Model] Pipeline parallel support for Mixtral (#6516)
parent
b5241e41
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
19 deletions
+60
-19
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+11
-6
vllm/config.py
vllm/config.py
+1
-0
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+48
-13
No files found.
tests/distributed/test_pipeline_parallel.py
View file @
b5af8c22
import
pytest
from
transformers
import
AutoTokenizer
from
..utils
import
RemoteOpenAIServer
...
...
@@ -12,6 +13,8 @@ from ..utils import RemoteOpenAIServer
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
),
])
def
test_compare_tp
(
TP_SIZE
,
PP_SIZE
,
EAGER_MODE
,
CHUNKED_PREFILL
,
MODEL_NAME
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
)
pp_args
=
[
# use half precision for speed and memory savings in CI environment
"--dtype"
,
...
...
@@ -34,7 +37,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
"--dtype"
,
"bfloat16"
,
"--tensor-parallel-size"
,
str
(
max
(
TP_SIZE
,
2
)),
#
use at least TP_SIZE=2 to hold the model
str
(
max
(
TP_SIZE
,
2
)),
#
We only use 2 GPUs in the CI.
"--distributed-executor-backend"
,
"mp"
,
]
...
...
@@ -45,8 +48,10 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
pp_args
.
append
(
"--enforce-eager"
)
tp_args
.
append
(
"--enforce-eager"
)
prompt
=
"Hello, my name is"
token_ids
=
tokenizer
(
prompt
)[
"input_ids"
]
results
=
[]
for
args
in
[
pp_args
,
tp_args
]
:
for
args
in
(
pp_args
,
tp_args
)
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
server
:
client
=
server
.
get_client
()
...
...
@@ -62,7 +67,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
# test with text prompt
completion
=
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
"Hello, my name is"
,
prompt
=
prompt
,
max_tokens
=
5
,
temperature
=
0.0
)
...
...
@@ -76,7 +81,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
# test using token IDs
completion
=
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
[
0
,
0
,
0
,
0
,
0
]
,
prompt
=
token_ids
,
max_tokens
=
5
,
temperature
=
0.0
,
)
...
...
@@ -91,7 +96,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
# test simple list
batch
=
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
[
"Hello, my name is"
,
"Hello, my name is"
],
prompt
=
[
prompt
,
prompt
],
max_tokens
=
5
,
temperature
=
0.0
,
)
...
...
@@ -105,7 +110,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
# test streaming
batch
=
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
[
"Hello, my name is"
,
"Hello, my name is"
],
prompt
=
[
prompt
,
prompt
],
max_tokens
=
5
,
temperature
=
0.0
,
stream
=
True
,
...
...
vllm/config.py
View file @
b5af8c22
...
...
@@ -34,6 +34,7 @@ _PP_SUPPORTED_MODELS = [
"MistralForCausalLM"
,
"Phi3ForCausalLM"
,
"GPT2LMHeadModel"
,
"MixtralForCausalLM"
,
]
...
...
vllm/model_executor/models/mixtral.py
View file @
b5af8c22
...
...
@@ -29,7 +29,7 @@ from transformers import MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
...
...
@@ -48,6 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
MixtralMoE
(
nn
.
Module
):
...
...
@@ -255,12 +256,11 @@ class MixtralModel(nn.Module):
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
layers
=
nn
.
ModuleList
([
MixtralDecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
:
MixtralDecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
))
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
...
...
@@ -269,14 +269,25 @@ class MixtralModel(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
residual
)
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -347,7 +358,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -356,6 +367,20 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
...
...
@@ -392,6 +417,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -402,6 +431,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
...
...
@@ -414,6 +446,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment