Commit 404422f4 (unverified)
Authored Jul 03, 2023 by Woosuk Kwon; committed by GitHub on Jul 03, 2023

[Model] Add support for MPT (#334)

Parent: 7717d083
Showing 11 changed files with 388 additions and 4 deletions (+388, -4)
README.md                                      +1   -0
csrc/attention/attention_kernels.cu            +3   -0
docs/source/models/supported_models.rst        +3   -0
vllm/config.py                                 +3   -2
vllm/model_executor/layers/attention.py        +1   -1
vllm/model_executor/model_loader.py            +2   -1
vllm/model_executor/models/__init__.py         +2   -0
vllm/model_executor/models/mpt.py              +279 -0
vllm/transformers_utils/config.py              +15  -0
vllm/transformers_utils/configs/__init__.py    +5   -0
vllm/transformers_utils/configs/mpt.py         +74  -0
README.md
@@ -46,6 +46,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
 - LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)

 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
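Usage note for the newly listed MPT checkpoints: a minimal offline-inference sketch, assuming the `vllm.LLM` and `SamplingParams` API shipped around this release (argument names may have differed slightly); it is not part of this commit.

# Illustrative sketch, not part of the commit: generate with an MPT checkpoint.
from vllm import LLM, SamplingParams

llm = LLM(model="mosaicml/mpt-7b")  # ModelConfig now resolves the config via get_config (see vllm/config.py below)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)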
csrc/attention/attention_kernels.cu
@@ -395,6 +395,9 @@ void single_query_cached_kv_attention_launcher(
     case 96:
       LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
       break;
+    case 112:
+      LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
+      break;
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
docs/source/models/supported_models.rst
@@ -29,6 +29,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`LlamaForCausalLM`
     - LLaMA, Vicuna, Alpaca, Koala, Guanaco
     - :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
+  * - :code:`MPTForCausalLM`
+    - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
+    - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
   * - :code:`OPTForCausalLM`
     - OPT, OPT-IML
     - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
vllm/config.py
 from typing import Optional

 import torch
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig

 from vllm.logger import init_logger
+from vllm.transformers_utils.config import get_config
 from vllm.utils import get_cpu_memory

 logger = init_logger(__name__)

@@ -49,7 +50,7 @@ class ModelConfig:
         self.use_dummy_weights = use_dummy_weights
         self.seed = seed

-        self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
+        self.hf_config = get_config(model)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_tokenizer_mode()
vllm/model_executor/layers/attention.py
@@ -12,7 +12,7 @@ from vllm import cache_ops
 from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata

-_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]


 class PagedAttention(nn.Module):
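Note on the new 112 entry (mirrored by the `case 112` kernel instantiation above): it matches MPT-30B's per-head dimension. A quick check, with d_model/n_heads values assumed from the public MPT model cards rather than from this commit:

# Per-head size = d_model // n_heads; the shapes below are assumptions taken
# from the public MPT model cards, not from this commit.
mpt_shapes = {
    "mosaicml/mpt-7b": (4096, 32),   # -> head size 128, already supported
    "mosaicml/mpt-30b": (7168, 64),  # -> head size 112, newly supported
}
for name, (d_model, n_heads) in mpt_shapes.items():
    print(f"{name}: head size {d_model // n_heads}")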
vllm/model_executor/model_loader.py
@@ -16,7 +16,8 @@ _MODEL_REGISTRY = {
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
-    "LLaMAForCausalLM": LlamaForCausalLM,
+    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
+    "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
 }
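For orientation, a hedged sketch of how a registry keyed by architecture name is typically consulted when building a model from a Hugging Face config; `_get_model_class` below is a hypothetical helper, not the actual lookup code in `model_loader.py`.

# Hypothetical helper, for illustration only: dispatch on hf_config.architectures
# (e.g. ["MPTForCausalLM"]) to one of the registered model classes.
from transformers import PretrainedConfig

def _get_model_class(hf_config: PretrainedConfig, registry: dict):
    architectures = getattr(hf_config, "architectures", []) or []
    for arch in architectures:
        if arch in registry:
            return registry[arch]
    raise ValueError(f"Model architectures {architectures} are not supported.")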
vllm/model_executor/models/__init__.py
@@ -3,6 +3,7 @@ from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM

 __all__ = [
@@ -11,5 +12,6 @@ __all__ = [
     "GPTBigCodeForCausalLM",
     "GPTNeoXForCausalLM",
     "LlamaForCausalLM",
+    "MPTForCausalLM",
     "OPTForCausalLM",
 ]
vllm/model_executor/models/mpt.py  (new file, 0 → 100644)

# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


class MPTAttention(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        self.qkv_proj = ColumnParallelLinear(
            self.d_model,
            3 * self.d_model,
            bias=not config.no_bias,
            gather_output=False,
            perform_initialization=False,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            input_is_parallel=True,
            perform_initialization=False,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        del position_ids  # unused.
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class MPTMLP(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(hidden_size,
                                            intermediate_size,
                                            bias=not config.no_bias,
                                            gather_output=False,
                                            perform_initialization=False)
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=not config.no_bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MPTBlock(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MPTModel(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.d_model,
                                          perform_initialization=False)
        self.blocks = nn.ModuleList(
            [MPTBlock(config) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias"):
                    if isinstance(module.bias, nn.Parameter):
                        # Remove the bias term in Linear and LayerNorm.
                        module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings

        self.transformer = MPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
    _row_parallel_weights = ["out_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "Wqkv" in name:
                # NOTE(woosuk): MPT's fused QKV has the shape of
                # [3 * num_heads * head_size, hidden_size].
                # When tensor model parallelism is used, we need to shard
                # the weight along the hidden dimension.
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tp_world_size
                head_start = tp_rank * num_heads
                head_end = (tp_rank + 1) * num_heads

                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected parameter name {name}")
                name = name.replace("Wqkv", "qkv_proj")
            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
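To make the fused `Wqkv` sharding in `load_weights` concrete: a standalone sketch with toy dimensions that mirrors the view/slice/reshape above (illustrative only, not part of the commit).

# Toy illustration of sharding a fused QKV weight of shape
# [3 * num_heads * head_size, hidden_size] along the head dimension.
import torch

total_num_heads, head_size, hidden_size = 4, 8, 32  # toy values
tp_world_size, tp_rank = 2, 1                        # pretend rank 1 of 2
num_heads = total_num_heads // tp_world_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

wqkv = torch.randn(3 * total_num_heads * head_size, hidden_size)
shard = (wqkv.view(3, total_num_heads, head_size, hidden_size)
         [:, head_start:head_end, :, :]
         .reshape(-1, hidden_size))
print(shard.shape)  # torch.Size([48, 32]) == [3 * num_heads * head_size, hidden_size]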
vllm/transformers_utils/config.py  (new file, 0 → 100644)

from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import

_CONFIG_REGISTRY = {
    "mpt": MPTConfig,
}


def get_config(model: str) -> PretrainedConfig:
    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model)
    return config
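For MPT checkpoints, `AutoConfig.from_pretrained(..., trust_remote_code=True)` reports `model_type == "mpt"`, so `get_config` swaps the remote-code config object for the local `MPTConfig`. A usage sketch (requires network access to the Hugging Face Hub; not part of the commit):

# Illustrative usage of get_config.
from vllm.transformers_utils.config import get_config

config = get_config("mosaicml/mpt-7b")
print(type(config).__name__)                           # MPTConfig
print(config.hidden_size, config.num_attention_heads)  # aliased via attribute_map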
vllm/transformers_utils/configs/__init__.py  (new file, 0 → 100644)

from vllm.transformers_utils.configs.mpt import MPTConfig

__all__ = [
    "MPTConfig",
]
vllm/transformers_utils/configs/mpt.py  (new file, 0 → 100644)

# Adapted from
# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
from typing import Any, Dict, Optional, Union

from transformers import PretrainedConfig

_ATTN_CONFIG_DEFAULTS = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "triton",
    "qk_ln": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "alibi": False,
    "alibi_bias_max": 8,
}


class MPTConfig(PretrainedConfig):
    model_type = "mpt"
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        expansion_ratio: int = 4,
        max_seq_len: int = 2048,
        vocab_size: int = 50368,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        learned_pos_emb: bool = True,
        attn_config: Optional[Dict[str, Any]] = None,
        init_device: str = "cpu",
        logit_scale: Optional[Union[float, str]] = None,
        no_bias: bool = False,
        verbose: int = 0,
        embedding_fraction: float = 1.0,
        norm_type: str = "low_precision_layernorm",
        use_cache: bool = False,
        **kwargs,
    ) -> None:
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        if attn_config is None:
            self.attn_config = _ATTN_CONFIG_DEFAULTS
        else:
            self.attn_config = attn_config
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.verbose = verbose
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.use_cache = use_cache
        if "name" in kwargs:
            del kwargs["name"]
        if "loss_fn" in kwargs:
            del kwargs["loss_fn"]
        super().__init__(**kwargs)
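The `attribute_map` above is what lets the rest of vLLM read generic names such as `hidden_size` and `num_attention_heads` from an MPT config. A standalone re-creation of the mechanism using only `transformers` (the toy class is illustrative, not vLLM code):

# Minimal demo of PretrainedConfig.attribute_map with a hypothetical toy config.
from transformers import PretrainedConfig

class ToyConfig(PretrainedConfig):
    model_type = "toy"
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "n_heads",
    }

    def __init__(self, d_model: int = 2048, n_heads: int = 16, **kwargs):
        self.d_model = d_model
        self.n_heads = n_heads
        super().__init__(**kwargs)

cfg = ToyConfig(d_model=4096, n_heads=32)
print(cfg.hidden_size, cfg.num_attention_heads)  # 4096 32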