Commit 1fe09900 (unverified)
Remove `MPTConfig` (#1529)

Authored Nov 01, 2023 by Woosuk Kwon, committed by GitHub on Nov 01, 2023
Parent: 7e90a2d1

Showing 6 changed files with 26 additions and 102 deletions (+26 -102)
vllm/model_executor/model_loader.py          +2   -2
vllm/model_executor/models/__init__.py       +2   -2
vllm/model_executor/models/mpt.py            +20  -20
vllm/transformers_utils/config.py            +2   -2
vllm/transformers_utils/configs/__init__.py  +0   -2
vllm/transformers_utils/configs/mpt.py       +0   -74
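For orientation (not part of the diff itself): recent transformers releases register MPT natively, so the config class that the vendored vllm/transformers_utils/configs/mpt.py duplicated is available upstream as transformers.MptConfig. A minimal sketch, assuming a transformers version that ships MPT support:

from transformers import MptConfig

# The upstream class exposes the same core hyperparameters the vendored
# MPTConfig carried (d_model, n_heads, n_layers, ...), so vLLM can type its
# MPT model code directly against transformers.MptConfig.
config = MptConfig()
print(config.model_type)                                # "mpt"
print(config.d_model, config.n_heads, config.n_layers)  # MPT-style names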
vllm/model_executor/model_loader.py
@@ -28,8 +28,8 @@ _MODEL_REGISTRY = {
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     "MistralForCausalLM": MistralForCausalLM,
     # transformers's mpt class has lower case
-    "MptForCausalLM": MPTForCausalLM,
-    "MPTForCausalLM": MPTForCausalLM,
+    "MptForCausalLM": MptForCausalLM,
+    "MPTForCausalLM": MptForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
     "QWenLMHeadModel": QWenLMHeadModel,
     "RWForCausalLM": FalconForCausalLM,
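Both spellings of the architecture name stay registered because checkpoints in the wild report either "MPTForCausalLM" (older MosaicML configs) or "MptForCausalLM" (the lower-case class name transformers adopted); only the vLLM class they resolve to is renamed. An illustrative sketch of the dispatch, with simplified placeholder values rather than vLLM's actual classes:

# Illustrative only: vLLM selects the model implementation by looking up the
# "architectures" entry of the Hugging Face config in _MODEL_REGISTRY, so both
# spellings map to the same implementation.
_MODEL_REGISTRY = {
    "MptForCausalLM": "MptForCausalLM (vLLM class)",   # placeholder value
    "MPTForCausalLM": "MptForCausalLM (vLLM class)",   # placeholder value
}

def resolve_model_cls(hf_config):
    architectures = getattr(hf_config, "architectures", None) or []
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Unsupported model architectures: {architectures}")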
vllm/model_executor/models/__init__.py
@@ -10,7 +10,7 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.internlm import InternLMForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mpt import MPTForCausalLM
+from vllm.model_executor.models.mpt import MptForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.models.qwen import QWenLMHeadModel
@@ -26,7 +26,7 @@ __all__ = [
     "GPTNeoXForCausalLM",
     "InternLMForCausalLM",
     "LlamaForCausalLM",
-    "MPTForCausalLM",
+    "MptForCausalLM",
     "OPTForCausalLM",
     "QWenLMHeadModel",
     "MistralForCausalLM",
vllm/model_executor/models/mpt.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 import torch
 import torch.nn as nn
+from transformers import MptConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
@@ -19,7 +20,6 @@ from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                        ColumnParallelLinear,
                                                        RowParallelLinear)
 from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.mpt import MPTConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -37,17 +37,17 @@ def _get_alibi_slopes(
     return slopes
 
 
-class MPTAttention(nn.Module):
+class MptAttention(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
-        assert not config.attn_config["prefix_lm"]
-        assert config.attn_config["alibi"]
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
+        assert not config.attn_config.prefix_lm
+        assert config.attn_config.alibi
 
         self.qkv_proj = ColumnParallelLinear(
             self.d_model,
@@ -105,9 +105,9 @@ class MPTAttention(nn.Module):
         return output
 
 
-class MPTMLP(nn.Module):
+class MptMLP(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         hidden_size = config.d_model
         expansion_ratio = config.expansion_ratio
@@ -133,15 +133,15 @@ class MPTMLP(nn.Module):
         return x
 
 
-class MPTBlock(nn.Module):
+class MptBlock(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config)
+        self.attn = MptAttention(config)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config)
+        self.ffn = MptMLP(config)
 
     def forward(
         self,
@@ -166,9 +166,9 @@ class MPTBlock(nn.Module):
         return hidden_states
 
 
-class MPTModel(nn.Module):
+class MptModel(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         assert config.embedding_fraction == 1.0
         assert config.norm_type == "low_precision_layernorm"
@@ -178,7 +178,7 @@ class MPTModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MPTBlock(config) for _ in range(config.n_layers)])
+            [MptBlock(config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -213,14 +213,14 @@ class MPTModel(nn.Module):
         return hidden_states
 
 
-class MPTForCausalLM(nn.Module):
+class MptForCausalLM(nn.Module):
 
-    def __init__(self, config: MPTConfig):
+    def __init__(self, config: MptConfig):
         super().__init__()
         self.config = config
         assert config.tie_word_embeddings
-        self.transformer = MPTModel(config)
+        self.transformer = MptModel(config)
 
         # TODO(zhuohan): create a new weight after implementing pipeline
         # parallelism
         self.lm_head_weight = self.transformer.wte.weight
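Most of the churn in mpt.py is the mechanical MPT* to Mpt* rename; the one behavioural difference is that the vendored config stored attn_config as a plain dict, while transformers.MptConfig nests a config object, so lookups move from config.attn_config["clip_qkv"] to config.attn_config.clip_qkv. A quick illustration, assuming a transformers version that ships MptConfig:

from transformers import MptConfig

config = MptConfig()
attn = config.attn_config            # nested attention config object, not a dict
print(attn.clip_qkv, attn.qk_ln, attn.alibi_bias_max)
print(attn.prefix_lm, attn.alibi)    # the asserts in MptAttention check these flags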
vllm/transformers_utils/config.py
 from typing import Optional
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import AutoConfig, MptConfig, PretrainedConfig
 
 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
 
 _CONFIG_REGISTRY = {
-    "mpt": MPTConfig,
+    "mpt": MptConfig,
     "baichuan": BaiChuanConfig,
     "aquila": AquilaConfig,
     "qwen": QWenConfig,
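_CONFIG_REGISTRY maps a config.json model_type to the config class vLLM should load it with; after this change, "mpt" points at transformers' own MptConfig rather than the removed vendored class. A rough sketch of how such a registry is consulted (the helper name below is hypothetical, not vLLM's actual API):

from transformers import AutoConfig, MptConfig, PretrainedConfig

_CONFIG_REGISTRY = {"mpt": MptConfig}

def load_config(model: str) -> PretrainedConfig:   # hypothetical helper
    config = AutoConfig.from_pretrained(model)
    config_cls = _CONFIG_REGISTRY.get(config.model_type)
    if config_cls is not None:
        # Re-load with the registered class so model code sees its interface.
        config = config_cls.from_pretrained(model)
    return config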
vllm/transformers_utils/configs/__init__.py
-from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 from vllm.transformers_utils.configs.aquila import AquilaConfig
 from vllm.transformers_utils.configs.qwen import QWenConfig
@@ -8,7 +7,6 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
 from vllm.transformers_utils.configs.falcon import RWConfig
 
 __all__ = [
-    "MPTConfig",
     "BaiChuanConfig",
     "AquilaConfig",
     "QWenConfig",
vllm/transformers_utils/configs/mpt.py (deleted, 100644 → 0)
# Adapted from
# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
from typing import Any, Dict, Optional, Union

from transformers import PretrainedConfig

_ATTN_CONFIG_DEFAULTS = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "triton",
    "qk_ln": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "alibi": False,
    "alibi_bias_max": 8,
}


class MPTConfig(PretrainedConfig):
    model_type = "mpt"
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        expansion_ratio: int = 4,
        max_seq_len: int = 2048,
        vocab_size: int = 50368,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        learned_pos_emb: bool = True,
        attn_config: Optional[Dict[str, Any]] = None,
        init_device: str = "cpu",
        logit_scale: Optional[Union[float, str]] = None,
        no_bias: bool = False,
        verbose: int = 0,
        embedding_fraction: float = 1.0,
        norm_type: str = "low_precision_layernorm",
        use_cache: bool = False,
        **kwargs,
    ) -> None:
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        if attn_config is None:
            self.attn_config = _ATTN_CONFIG_DEFAULTS
        else:
            self.attn_config = attn_config
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.verbose = verbose
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.use_cache = use_cache
        if "name" in kwargs:
            del kwargs["name"]
        if "loss_fn" in kwargs:
            del kwargs["loss_fn"]
        super().__init__(**kwargs)