vllm · Commits · ab3a5a82

Support OLMo models. (#2832)

Unverified commit ab3a5a82, authored Feb 19, 2024 by Isotr0py, committed by GitHub Feb 18, 2024.
Parent: a61f0521

Showing 7 changed files with 471 additions and 5 deletions (+471, -5):
- README.md (+1, -0)
- docs/source/models/supported_models.rst (+3, -0)
- tests/models/test_models.py (+14, -5)
- vllm/model_executor/models/__init__.py (+1, -0)
- vllm/model_executor/models/olmo.py (+378, -0)
- vllm/transformers_utils/configs/__init__.py (+2, -0)
- vllm/transformers_utils/configs/olmo.py (+72, -0)
README.md

```diff
@@ -70,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
 - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
+- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
```
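With this change an OLMo checkpoint can be used through the same entry points as the other architectures in the list above. A minimal sketch of offline generation (the model name comes from the list above; the prompt and sampling settings are arbitrary, and since the OLMo repos ship a custom config class at the time of this commit, `trust_remote_code=True` is assumed to be required):

```python
from vllm import LLM, SamplingParams

# "allenai/OLMo-1B" is one of the newly supported checkpoints listed above.
llm = LLM(model="allenai/OLMo-1B", trust_remote_code=True)
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["OLMo is an open language model that"], params)
print(outputs[0].outputs[0].text)
```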
docs/source/models/supported_models.rst

```diff
@@ -62,6 +62,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`MPTForCausalLM`
     - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
     - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
+  * - :code:`OLMoForCausalLM`
+    - OLMo
+    - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc.
   * - :code:`OPTForCausalLM`
     - OPT, OPT-IML
     - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
```
tests/models/test_models.py

```diff
@@ -5,11 +5,20 @@ Run `pytest tests/models/test_models.py --forked`.
 import pytest
 
 MODELS = [
-    "facebook/opt-125m", "meta-llama/Llama-2-7b-hf",
-    "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2",
-    "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
-    "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
-    "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+    "mistralai/Mistral-7B-v0.1",
+    "Deci/DeciLM-7b",
+    "tiiuae/falcon-7b",
+    "gpt2",
+    "bigcode/tiny_starcoder_py",
+    "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-70m",
+    "bigscience/bloom-560m",
+    "mosaicml/mpt-7b",
+    "microsoft/phi-2",
+    "stabilityai/stablelm-3b-4e1t",
+    "allenai/OLMo-1B",
 ]
```
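The parametrized test compares greedy completions from the Hugging Face reference implementation against vLLM for every entry in `MODELS`. A rough standalone sketch of that check for the new OLMo entry (the real test uses the repo's `hf_runner`/`vllm_runner` fixtures; the prompt and token budget here are arbitrary examples):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "allenai/OLMo-1B"
prompt = "The capital of France is"

# Hugging Face reference output (greedy decoding).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(model_name,
                                                trust_remote_code=True)
hf_ids = hf_model.generate(**tokenizer(prompt, return_tensors="pt"),
                           max_new_tokens=32, do_sample=False)
hf_text = tokenizer.decode(hf_ids[0], skip_special_tokens=True)

# vLLM output under the same greedy settings.
llm = LLM(model=model_name, trust_remote_code=True)
vllm_out = llm.generate([prompt],
                        SamplingParams(temperature=0.0, max_tokens=32))
vllm_text = prompt + vllm_out[0].outputs[0].text

assert hf_text == vllm_text
```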
vllm/model_executor/models/__init__.py

```diff
@@ -35,6 +35,7 @@ _MODELS = {
     # transformers's mpt class has lower case
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
```
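For context, the `_MODELS` table maps the `architectures` string found in a checkpoint's Hugging Face config to a `(module, class)` pair under `vllm.model_executor.models`. A minimal sketch of that lookup (not vLLM's actual registry code, just an illustration of how the new entry gets resolved):

```python
import importlib

# Trimmed-down copy of the registry entry added by this commit.
_MODELS = {"OLMoForCausalLM": ("olmo", "OLMoForCausalLM")}


def resolve_model_cls(architecture: str):
    """Turn an HF `architectures` entry into the vLLM model class."""
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)


# e.g. resolve_model_cls("OLMoForCausalLM") returns the class defined
# in the new vllm/model_executor/models/olmo.py below.
```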
vllm/model_executor/models/olmo.py (new file, mode 100644)

```python
# coding=utf-8
# Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size, )
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.olmo import OLMoConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


class SwiGLU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

    @property
    def output_multiplier(self) -> float:
        return 0.5


class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as
    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.d_model
        assert config.d_model % config.n_heads == 0
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = self.config.n_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // self.total_num_heads

        # Layer norms.
        self.attn_norm = nn.LayerNorm(config.d_model,
                                      elementwise_affine=False,
                                      bias=False)
        # Attention input projection. Projects x -> (q, k, v)
        self.att_proj = QKVParallelLinear(
            config.d_model,
            self.head_dim,
            self.total_num_heads,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Rotary embeddings.
        if self.config.rope:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
        self.scaling = self.head_dim**-0.5
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

        # Attention output projection.
        self.attn_out = RowParallelLinear(
            config.d_model,
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.attn_norm(hidden_states)
        qkv, _ = self.att_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.config.rope:
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.attn_out(attn_output)
        return output


class OlmoMLP(nn.Module):
    """
    This is the MLP block where the output is computed as ``MLP(LN(x))`` in
    ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
                            is not None else config.mlp_ratio * config.d_model)

        # Layer norms.
        self.ff_norm = nn.LayerNorm(config.d_model,
                                    elementwise_affine=False,
                                    bias=False)

        # Feed-forward input projection.
        self.ff_proj = ColumnParallelLinear(
            config.d_model,
            self.hidden_size,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Activation function.
        # self.act = SiluAndMul()
        # self.act.output_multiplier = 0.5
        self.act = SwiGLU()
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Feed-forward output projection.
        self.ff_out = RowParallelLinear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # Add feed-forward projection.
        # shape: (batch_size, seq_len, d_model)
        og_x = x
        x = self.ff_norm(x)
        x, _ = self.ff_proj(x)
        x = self.act(x)
        x, _ = self.ff_out(x)
        x = og_x + x
        return x


class OlmoBlock(nn.Module):
    """
    This is a typical transformer block where the output is computed as
    ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        # Attention block.
        self.attn = OlmoAttention(config, linear_method)

        # MLP block.
        self.mlp = OlmoMLP(config, linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Attention block.
        og_x = hidden_states
        x = self.attn(positions, hidden_states, kv_cache, input_metadata)
        x = x + og_x

        # MLP block.
        hidden_states = self.mlp(x)
        return hidden_states


class OlmoModel(nn.Module):

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=VocabParallelEmbedding(
                    config.embedding_size or config.vocab_size,
                    config.d_model,
                ),
                ln_f=nn.LayerNorm(config.d_model,
                                  elementwise_affine=False,
                                  bias=False),
            ))

        blocks = [
            OlmoBlock(config, linear_method) for i in range(config.n_layers)
        ]
        if self.config.block_group_size > 1:
            raise NotImplementedError("Block group size > 1 not supported yet")
        else:
            self.transformer.update({"blocks": nn.ModuleList(blocks)})

        if not config.weight_tying:
            self.transformer.update({
                "ff_out":
                ColumnParallelLinear(
                    config.d_model,
                    config.embedding_size or config.vocab_size,
                    bias=config.include_bias,
                    linear_method=linear_method,
                )
            })

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        """
        # Get embeddings of input.
        # shape: (batch_size, seq_len, d_model)
        x = self.transformer.wte(input_ids)  # type: ignore

        # Apply blocks one-by-one.
        for block_idx, block in enumerate(self.transformer.blocks):
            # shape: (batch_size, seq_len, d_model)
            x = block(
                positions,
                x,
                kv_caches[block_idx],
                input_metadata,
            )

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        x = self.transformer.ln_f(x)  # type: ignore
        return x


class OLMoForCausalLM(nn.Module):
    """
    Extremely barebones HF model wrapper.
    """

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OlmoModel(config, linear_method)
        self.lm_head_weight = (self.model.transformer.wte.weight
                               if config.weight_tying else
                               self.model.transformer.ff_out.weight)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # attention
            if ".att" in name:
                name = name.replace(".att", ".attn.att")
            # mlp
            if ".ff" in name and "transformer.ff_out" not in name:
                name = name.replace(".ff", ".mlp.ff")
            # there is no bias in olmo
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
```
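One detail worth noting in the new file: the `SwiGLU` activation splits its input in half along the last dimension and gates one half with the other, which is why `ff_out` is sized with `output_multiplier = 0.5`. A tiny sanity-check sketch (not part of the commit):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)          # last dimension = 8 (the MLP's hidden_size)
h, gate = x.chunk(2, dim=-1)   # two halves of size 4
out = F.silu(gate) * h
assert out.shape == (2, 4)     # half the features: hidden_size * 0.5
```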
vllm/transformers_utils/configs/__init__.py

```diff
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
+from vllm.transformers_utils.configs.olmo import OLMoConfig
 from vllm.transformers_utils.configs.qwen import QWenConfig
 # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
@@ -11,6 +12,7 @@ __all__ = [
     "BaiChuanConfig",
     "ChatGLMConfig",
     "MPTConfig",
+    "OLMoConfig",
     "QWenConfig",
     "RWConfig",
 ]
```
vllm/transformers_utils/configs/olmo.py (new file, mode 100644)

```python
# coding=utf-8
# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py
"""OLMo configuration"""
from transformers import PretrainedConfig


class OLMoConfig(PretrainedConfig):
    model_type = 'olmo'
    attribute_map = {
        'num_attention_heads': 'n_heads',
        'hidden_size': 'd_model',
        'num_hidden_layers': 'n_layers',
    }

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
    def __init__(
        self,
        d_model=768,
        n_heads=12,
        n_layers=12,
        mlp_ratio=4,
        mlp_hidden_size=None,
        activation_type="swiglu",
        block_type="sequential",
        block_group_size=1,
        alibi=False,
        alibi_bias_max=8.0,
        rope=False,
        rope_full_precision=True,
        multi_query_attention=False,
        attention_layer_norm=False,
        layer_norm_type="default",
        layer_norm_with_affine=True,
        attention_layer_norm_with_affine=True,
        max_sequence_length=1024,
        include_bias=True,
        bias_for_layer_norm=None,
        scale_logits=False,
        vocab_size=50257,
        embedding_size=50304,
        weight_tying=True,
        eos_token_id=50256,
        pad_token_id=50256,
        **kwargs,
    ):
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_size = mlp_hidden_size
        self.activation_type = activation_type
        self.block_type = block_type
        self.block_group_size = block_group_size
        self.alibi = alibi
        self.alibi_bias_max = alibi_bias_max
        self.rope = rope
        self.rope_full_precision = rope_full_precision
        self.multi_query_attention = multi_query_attention
        self.attention_layer_norm = attention_layer_norm
        self.layer_norm_type = layer_norm_type
        self.layer_norm_with_affine = layer_norm_with_affine
        self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
        self.max_sequence_length = max_sequence_length
        self.include_bias = include_bias
        self.bias_for_layer_norm = bias_for_layer_norm
        self.scale_logits = scale_logits
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.weight_tying = weight_tying
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        super().__init__(**kwargs)
```
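The `attribute_map` is what lets the rest of vLLM read the usual Hugging Face field names from this config. A quick illustration (the hyperparameter values are assumptions, roughly OLMo-1B-sized, not taken from the commit):

```python
from vllm.transformers_utils.configs.olmo import OLMoConfig

config = OLMoConfig(d_model=2048, n_heads=16, n_layers=16, rope=True)

# PretrainedConfig resolves these aliases through attribute_map.
assert config.hidden_size == config.d_model == 2048
assert config.num_attention_heads == config.n_heads == 16
assert config.num_hidden_layers == config.n_layers == 16
```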