Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c07dd28
Unverified
Commit
4c07dd28
authored
Mar 21, 2024
by
Lalit Pradhan
Committed by
GitHub
Mar 21, 2024
Browse files
[
🚀
Ready to be merged] Added support for Jais models (#3183)
parent
3bbff9e5
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
596 additions
and
3 deletions
+596
-3
README.md
README.md
+1
-0
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+5
-1
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+1
-2
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+351
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+1
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/jais.py
vllm/transformers_utils/configs/jais.py
+234
-0
No files found.
README.md
View file @
4c07dd28
...
@@ -76,6 +76,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
...
@@ -76,6 +76,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
-
GPT-NeoX (
`EleutherAI/gpt-neox-20b`
,
`databricks/dolly-v2-12b`
,
`stabilityai/stablelm-tuned-alpha-7b`
, etc.)
-
GPT-NeoX (
`EleutherAI/gpt-neox-20b`
,
`databricks/dolly-v2-12b`
,
`stabilityai/stablelm-tuned-alpha-7b`
, etc.)
-
InternLM (
`internlm/internlm-7b`
,
`internlm/internlm-chat-7b`
, etc.)
-
InternLM (
`internlm/internlm-7b`
,
`internlm/internlm-chat-7b`
, etc.)
-
InternLM2 (
`internlm/internlm2-7b`
,
`internlm/internlm2-chat-7b`
, etc.)
-
InternLM2 (
`internlm/internlm2-7b`
,
`internlm/internlm2-chat-7b`
, etc.)
-
Jais (
`core42/jais-13b`
,
`core42/jais-13b-chat`
,
`core42/jais-30b-v3`
,
`core42/jais-30b-chat-v3`
, etc.)
-
LLaMA & LLaMA-2 (
`meta-llama/Llama-2-70b-hf`
,
`lmsys/vicuna-13b-v1.3`
,
`young-geng/koala`
,
`openlm-research/open_llama_13b`
, etc.)
-
LLaMA & LLaMA-2 (
`meta-llama/Llama-2-70b-hf`
,
`lmsys/vicuna-13b-v1.3`
,
`young-geng/koala`
,
`openlm-research/open_llama_13b`
, etc.)
-
Mistral (
`mistralai/Mistral-7B-v0.1`
,
`mistralai/Mistral-7B-Instruct-v0.1`
, etc.)
-
Mistral (
`mistralai/Mistral-7B-v0.1`
,
`mistralai/Mistral-7B-Instruct-v0.1`
, etc.)
-
Mixtral (
`mistralai/Mixtral-8x7B-v0.1`
,
`mistralai/Mixtral-8x7B-Instruct-v0.1`
, etc.)
-
Mixtral (
`mistralai/Mixtral-8x7B-v0.1`
,
`mistralai/Mixtral-8x7B-Instruct-v0.1`
, etc.)
...
...
docs/source/models/supported_models.rst
View file @
4c07dd28
...
@@ -66,7 +66,11 @@ Alongside each architecture, we include some popular models that use it.
...
@@ -66,7 +66,11 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`InternLM2ForCausalLM`
* - :code:`InternLM2ForCausalLM`
- InternLM2
- InternLM2
- :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
- :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
-
-
* - :code:`JAISLMHeadModel`
- Jais
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
-
* - :code:`LlamaForCausalLM`
* - :code:`LlamaForCausalLM`
- LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
- LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
...
...
vllm/model_executor/models/__init__.py
View file @
4c07dd28
...
@@ -27,6 +27,7 @@ _MODELS = {
...
@@ -27,6 +27,7 @@ _MODELS = {
"GPTNeoXForCausalLM"
:
(
"gpt_neox"
,
"GPTNeoXForCausalLM"
),
"GPTNeoXForCausalLM"
:
(
"gpt_neox"
,
"GPTNeoXForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
# For decapoda-research/llama-*
# For decapoda-research/llama-*
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
...
...
vllm/model_executor/models/gpt2.py
View file @
4c07dd28
...
@@ -242,8 +242,7 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -242,8 +242,7 @@ class GPT2LMHeadModel(nn.Module):
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
logits
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/jais.py
0 → 100644
View file @
4c07dd28
# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.transformers_utils.configs
import
JAISConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_rank
,
)
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
,
)
from
vllm.sequence
import
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
# A (key_cache, value_cache) tensor pair; the attention layer in this file
# unpacks one such pair per transformer block.
KVCache = Tuple[torch.Tensor, torch.Tensor]
class SwiGLUActivation(nn.Module):
    """Gated activation: scales ``x1`` elementwise by ``silu(x2)``."""

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return ``x1 * silu(x2)``.

        Args:
            x1: Tensor that is scaled by the gate.
            x2: Tensor passed through SiLU to form the gate.
        """
        gate = nn.functional.silu(x2)
        return x1 * gate
def
_get_alibi_slopes
(
n
):
def
get_slopes_power_of_2
(
n
):
start
=
2
**
(
-
(
2
**-
(
math
.
log2
(
n
)
-
3
)))
ratio
=
start
return
[
start
*
ratio
**
i
for
i
in
range
(
n
)]
if
math
.
log2
(
n
).
is_integer
():
return
get_slopes_power_of_2
(
n
)
else
:
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
n
))
return
(
get_slopes_power_of_2
(
closest_power_of_2
)
+
_get_alibi_slopes
(
2
*
closest_power_of_2
)[
0
::
2
][:
n
-
closest_power_of_2
])
class JAISAttention(nn.Module):
    """Self-attention layer of a Jais block.

    Builds a fused QKV projection (``c_attn``), an output projection
    (``c_proj``) and an ``Attention`` op that receives the ALiBi slopes for
    the heads owned by the current tensor-parallel rank.
    """

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tp_world_size = get_tensor_model_parallel_world_size()
        # Heads are split evenly across tensor-parallel ranks.
        assert total_num_heads % tp_world_size == 0
        self.num_heads = total_num_heads // tp_world_size
        self.head_dim = self.hidden_size // total_num_heads

        # Some configs expose the flag as `scale_qk_dot_by_d`; mirror it onto
        # the `mup_`-prefixed attribute (this writes to the config object).
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        # Scale is head_dim ** -1.0 when the flag is set, else head_dim ** -0.5.
        if config.mup_scale_qk_dot_by_d:
            self.attn_scale_power = 1.0
        else:
            self.attn_scale_power = 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )

        # Keep only the slopes for the heads handled by this rank.
        tp_rank = get_tensor_model_parallel_rank()
        first_head = tp_rank * self.num_heads
        last_head = (tp_rank + 1) * self.num_heads
        all_slopes = _get_alibi_slopes(total_num_heads)
        alibi_slopes = all_slopes[first_head:last_head]
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            alibi_slopes=alibi_slopes,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection."""
        fused_qkv, _ = self.c_attn(hidden_states)
        # The fused projection is split into three equal parts on the last dim.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        context = self.attn(query, key, value, k_cache, v_cache,
                            input_metadata)
        projected, _ = self.c_proj(context)
        return projected
class
JAISMLP
(
nn
.
Module
):
def
__init__
(
self
,
intermediate_size
:
int
,
config
:
JAISConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
):
super
().
__init__
()
hidden_size
=
config
.
hidden_size
self
.
swiglu
=
config
.
activation_function
==
"swiglu"
self
.
c_fc
=
ColumnParallelLinear
(
hidden_size
,
intermediate_size
,
bias
=
True
,
linear_method
=
linear_method
,
)
self
.
c_fc2
=
(
ColumnParallelLinear
(
hidden_size
,
intermediate_size
,
bias
=
True
,
linear_method
=
linear_method
,
)
if
self
.
swiglu
else
None
)
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
True
,
linear_method
=
linear_method
,
)
self
.
act
=
SwiGLUActivation
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
swiglu
:
hidden_states2
,
_
=
self
.
c_fc2
(
hidden_states
)
hidden_states
,
_
=
self
.
c_fc
(
hidden_states
)
hidden_states
=
(
self
.
act
(
hidden_states
,
hidden_states2
)
if
self
.
swiglu
else
self
.
act
(
hidden_states
))
hidden_states
,
_
=
self
.
c_proj
(
hidden_states
)
return
hidden_states
class JAISBlock(nn.Module):
    """One Jais transformer block: pre-norm attention followed by pre-norm MLP.

    Each sub-layer is wrapped in a residual connection.
    """

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # The feed-forward width falls back to 4 * hidden_size when unset.
        if config.n_inner is not None:
            inner_dim = config.n_inner
        else:
            inner_dim = 4 * hidden_size

        norm_eps = config.layer_norm_epsilon
        self.ln_1 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.attn = JAISAttention(config, linear_method)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.mlp = JAISMLP(inner_dim, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, each with a residual add."""
        # Attention sub-layer.
        attn_input = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        after_attn = attn_output + hidden_states

        # Feed-forward sub-layer.
        mlp_input = self.ln_2(after_attn)
        mlp_output = self.mlp(mlp_input)
        return after_attn + mlp_output
class JAISModel(nn.Module):
    """Jais transformer backbone: embeddings, a stack of blocks, final norm.

    Learned position embeddings (``wpe``) are created only when
    ``config.position_embedding_type`` is not ``"alibi"``.
    """

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        # These configuration options are rejected by this implementation.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        if config.position_embedding_type != "alibi":
            self.wpe = nn.Embedding(config.max_position_embeddings,
                                    self.embed_dim)
        else:
            self.wpe = None

        # Prefer `embeddings_scale` when the config defines it.
        if not hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.mup_embeddings_scale
        else:
            self.embeddings_scale = config.embeddings_scale

        blocks = [
            JAISBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.h = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(self.embed_dim,
                                 eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed the tokens, run every block, and apply the final LayerNorm."""
        hidden_states = self.wte(input_ids)
        if self.wpe is not None:
            hidden_states = hidden_states + self.wpe(position_ids)

        # Scale the embeddings in place by the configured factor.
        scale = torch.tensor(float(self.embeddings_scale),
                             dtype=hidden_states.dtype)
        hidden_states *= scale

        # Block i uses the i-th entry of kv_caches.
        for idx, block in enumerate(self.h):
            hidden_states = block(hidden_states, kv_caches[idx],
                                  input_metadata)

        return self.ln_f(hidden_states)
class JAISLMHeadModel(nn.Module):
    """Jais backbone plus logits processing and sampling.

    The output projection reuses the token-embedding weight
    (``transformer.wte.weight``), and the logits are scaled by
    ``output_logits_scale``.
    """

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = JAISModel(config, linear_method)
        # The LM head shares its weight with the input embedding.
        self.lm_head_weight = self.transformer.wte.weight

        # Prefer `width_scale` when the config defines it; otherwise combine
        # the two muP factors.
        if not hasattr(config, "width_scale"):
            self.output_logits_scale = (config.mup_output_alpha *
                                        config.mup_width_scale)
        else:
            self.output_logits_scale = config.width_scale

        self.logits_processor = LogitsProcessor(
            vocab_size=config.vocab_size, scale=self.output_logits_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states."""
        return self.transformer(input_ids, positions, kv_caches,
                                input_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the tied embedding weight."""
        return self.logits_processor(self.lm_head_weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load checkpoint weights into this model's parameters.

        Entries whose name contains ``lm_head.weight``, ``.attn.bias``,
        ``.attn.masked_bias`` or ``relative_pe`` are skipped. Names without
        the ``transformer.`` prefix get it prepended before lookup.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        transposed_layers = ("c_attn", "c_proj", "c_fc")
        checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                               load_format, revision)
        for name, loaded_weight in checkpoint:
            # Skipped entries: the tied LM head, the attention-mask buffers
            # (NOTE: "c_attn.bias" does not match and is still loaded), and
            # anything named "relative_pe".
            if ("lm_head.weight" in name or ".attn.bias" in name
                    or ".attn.masked_bias" in name or "relative_pe" in name):
                continue

            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]

            # The checkpoint stores these projections in Conv1D layout, so
            # their weight matrices are transposed before loading.
            # NOTE: this logic might break quantized models.
            if name.endswith(".weight"):
                for layer_key in transposed_layers:
                    if layer_key in name:
                        loaded_weight = loaded_weight.t()

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
\ No newline at end of file
vllm/transformers_utils/config.py
View file @
4c07dd28
...
@@ -10,6 +10,7 @@ _CONFIG_REGISTRY = {
...
@@ -10,6 +10,7 @@ _CONFIG_REGISTRY = {
"RefinedWeb"
:
RWConfig
,
# For tiiuae/falcon-40b(-instruct)
"RefinedWeb"
:
RWConfig
,
# For tiiuae/falcon-40b(-instruct)
"RefinedWebModel"
:
RWConfig
,
# For tiiuae/falcon-7b(-instruct)
"RefinedWebModel"
:
RWConfig
,
# For tiiuae/falcon-7b(-instruct)
"starcoder2"
:
Starcoder2Config
,
"starcoder2"
:
Starcoder2Config
,
"jais"
:
JAISConfig
,
}
}
...
...
vllm/transformers_utils/configs/__init__.py
View file @
4c07dd28
...
@@ -5,10 +5,12 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
...
@@ -5,10 +5,12 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
# `FalconConfig` class from the official HuggingFace transformers library.
# `FalconConfig` class from the official HuggingFace transformers library.
from
vllm.transformers_utils.configs.falcon
import
RWConfig
from
vllm.transformers_utils.configs.falcon
import
RWConfig
from
vllm.transformers_utils.configs.starcoder2
import
Starcoder2Config
from
vllm.transformers_utils.configs.starcoder2
import
Starcoder2Config
from
vllm.transformers_utils.configs.jais
import
JAISConfig
__all__
=
[
__all__
=
[
"ChatGLMConfig"
,
"ChatGLMConfig"
,
"MPTConfig"
,
"MPTConfig"
,
"RWConfig"
,
"RWConfig"
,
"Starcoder2Config"
,
"Starcoder2Config"
,
"JAISConfig"
,
]
]
vllm/transformers_utils/configs/jais.py
0 → 100644
View file @
4c07dd28
# coding=utf-8
# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAIS configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
class JAISConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a
    [`JAISModel`]. It is used to instantiate a JAIS model according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used
    to control the model outputs. Read the documentation from
    [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the JAIS model. Defines the number of different
            tokens that can be represented by the
            `inputs_ids` passed when calling [`JAISModel`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used
            with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the
            Transformer encoder.
        n_inner (`int`, *optional*, defaults to None):
            Dimensionality of the inner feed-forward layers. `None` will set
            it to 4 times n_embd
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list
            `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in
            the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions (not used by all models).
        scale_attn_by_inverse_layer_idx (`bool`, *optional*,
            defaults to `False`):
            Whether to additionally scale attention weights by
            `1 / (layer_idx + 1)`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention
            (dot-product)
            and upcast attention dot-product/softmax to float() when training
            with mixed precision.
        position_embedding_type (`str`, *optional*, defaults to `"learned"`):
            Positional embedding can be either `"alibi"` or `"learned"`.
        mup_width_scale (`float`, *optional*, defaults to 1.0):
            muP parameter to scale learning rate and initializers. Calculated
            as (`d_model,0 / d_model`), where
            `d_model` is the model's width and `d_model,0` is the proxy
            model's width.
        mup_embeddings_scale (`float`, *optional*, defaults to 1.0):
            muP parameter to scale token and position embeddings.
        mup_output_alpha (`float`, *optional*, defaults to 1.0):
            muP parameter to scale output logits
            (`output_logits_scale = mup_output_alpha * mup_width_scale`).
        mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`):
            Scale attention weights by dividing by hidden_size instead of
            sqrt(hidden_size). Need to set scale_attn_weights to `True` as
            well.
        alibi_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for ALiBi
            embeddings. Currently only supports linear
            scaling strategy. Can specify either the scaling `factor` (must be
            a float greater than 1) for fixed scaling
            or `train_seq_len` for dynamic scaling on input samples with
            sequence length > `train_seq_len`. The expected
            formats are `{"type": strategy name, "factor": scaling factor}` or
            `{"type": strategy name,
            "train_seq_len": training sequence length}`.
        architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']):
            architecture names for Jais.

    Example:

    ```python
    >>> from transformers import JAISConfig, JAISModel

    >>> # Initializing a JAIS configuration
    >>> configuration = JAISConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = JAISModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "jais"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Generic attribute names mapped onto the GPT-2 style names stored here.
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        position_embedding_type="learned",
        mup_width_scale=1.0,
        mup_embeddings_scale=1.0,
        mup_output_alpha=1.0,
        mup_scale_qk_dot_by_d=False,
        alibi_scaling=None,
        architectures=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.position_embedding_type = position_embedding_type
        self.mup_width_scale = mup_width_scale
        self.mup_embeddings_scale = mup_embeddings_scale
        self.mup_output_alpha = mup_output_alpha
        self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d

        self.alibi_scaling = alibi_scaling
        # Reject a malformed `alibi_scaling` before the base class is set up.
        self._alibi_scaling_validation()
        if architectures is None:
            architectures = ["JAISLMHeadModel"]

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            architectures=architectures,
            **kwargs,
        )

    def _alibi_scaling_validation(self):
        """
        Validate the `alibi_scaling` configuration.

        `None` is accepted as-is. Otherwise the value must be a two-entry
        dict holding `"type": "linear"` plus either a float `factor` > 1.0
        or an integer `train_seq_len` > 1.

        Raises:
            ValueError: If `alibi_scaling` does not match that format.
        """
        if self.alibi_scaling is None:
            return

        if (not isinstance(self.alibi_scaling, dict)
                or len(self.alibi_scaling) != 2):
            raise ValueError(
                "`alibi_scaling` must be a dictionary with two fields, "
                "`type` and `factor` or `type` and `train_seq_len`, "
                f"got {self.alibi_scaling}")
        alibi_scaling_type = self.alibi_scaling.get("type", None)
        alibi_scaling_factor = self.alibi_scaling.get("factor", None)
        alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None)
        if alibi_scaling_type is None or alibi_scaling_type != "linear":
            raise ValueError(f"`alibi_scaling`'s type field must be 'linear', "
                             f"got {alibi_scaling_type}")
        # One of the two scaling fields must be supplied alongside `type`.
        if alibi_scaling_factor is None and alibi_dynamic_scaling is None:
            raise ValueError(
                "`alibi_scaling` must be a dictionary with two fields, "
                "`type` and `factor` or `type` and `train_seq_len`, "
                f"got {self.alibi_scaling}")
        # Each field is validated only when present; the parentheses keep the
        # comparison from running against `None`.
        if alibi_scaling_factor is not None and (
                not isinstance(alibi_scaling_factor, float)
                or alibi_scaling_factor <= 1.0):
            raise ValueError(
                f"`alibi_scaling`'s factor field must be a float > 1.0, "
                f"got {alibi_scaling_factor}")
        if alibi_dynamic_scaling is not None and (
                not isinstance(alibi_dynamic_scaling, int)
                or alibi_dynamic_scaling <= 1):
            raise ValueError(
                f"`alibi_scaling`'s `train_seq_len` field must be an "
                f"integer > 1, got {alibi_dynamic_scaling}")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment