Commit 1a2bbc93 (Unverified) in norm/vllm

ChatGLM Support (#1261)

Authored Nov 07, 2023 by GoHomeToMacDonal; committed via GitHub on Nov 06, 2023.
Parent: e7f579eb
Showing 7 changed files with 490 additions and 4 deletions (+490 -4).
vllm/config.py (+4 -0)
vllm/model_executor/model_loader.py (+1 -0)
vllm/model_executor/models/__init__.py (+2 -0)
vllm/model_executor/models/chatglm.py (+408 -0)
vllm/transformers_utils/config.py (+3 -2)
vllm/transformers_utils/configs/__init__.py (+4 -2)
vllm/transformers_utils/configs/chatglm.py (+68 -0)
vllm/config.py

@@ -166,6 +166,10 @@ class ModelConfig:
         if getattr(self.hf_config, "num_key_value_heads", None) is not None:
             return (self.hf_config.num_key_value_heads //
                     parallel_config.tensor_parallel_size)
+        # For ChatGLM-2:
+        if getattr(self.hf_config, "multi_query_group_num", None) is not None:
+            return (self.hf_config.multi_query_group_num //
+                    parallel_config.tensor_parallel_size)
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
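
For context: with this hunk, the KV-head helper in ModelConfig divides multi_query_group_num by the tensor-parallel size for ChatGLM-2 style configs instead of falling through to num_attention_heads. A minimal sketch of that arithmetic, using a stand-in object rather than the real vLLM classes, and assuming multi_query_group_num=2 (the grouped-KV value ChatGLM2-6B ships with) split across two tensor-parallel workers:

import types

# Hypothetical stand-in for the Hugging Face config; only the two attributes
# the branch above inspects are modeled here.
hf_config = types.SimpleNamespace(multi_query_group_num=2, num_attention_heads=32)
tensor_parallel_size = 2

if getattr(hf_config, "multi_query_group_num", None) is not None:
    # New ChatGLM-2 branch: 2 KV groups // 2 workers = 1 KV head per worker.
    num_kv_heads = hf_config.multi_query_group_num // tensor_parallel_size
else:
    # The old fallback would have reported 32 // 2 = 16 KV heads per worker.
    num_kv_heads = hf_config.num_attention_heads // tensor_parallel_size

print(num_kv_heads)  # 1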

vllm/model_executor/model_loader.py

@@ -18,6 +18,7 @@ _MODEL_REGISTRY = {
     "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
     "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
+    "ChatGLMModel": ChatGLMForCausalLM,
     "FalconForCausalLM": FalconForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
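
Registering the "ChatGLMModel" architecture string (the value ChatGLM checkpoints report in their config's "architectures" field) lets vLLM resolve such checkpoints to the new ChatGLMForCausalLM implementation. A minimal usage sketch; the model id and sampling values are illustrative, and trust_remote_code is needed because the ChatGLM repos ship a custom tokenizer:

from vllm import LLM, SamplingParams

# Assumes the THUDM/chatglm2-6b checkpoint; any ChatGLM-2 style repo whose
# config reports the "ChatGLMModel" architecture should resolve the same way.
llm = LLM(model="THUDM/chatglm2-6b", trust_remote_code=True)
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
outputs = llm.generate(["Briefly introduce yourself."], params)
print(outputs[0].outputs[0].text)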

vllm/model_executor/models/__init__.py

@@ -13,6 +13,7 @@ from vllm.model_executor.models.mistral import MistralForCausalLM
 from vllm.model_executor.models.mpt import MptForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.models.qwen import QWenLMHeadModel
+from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
 from vllm.model_executor.models.yi import YiForCausalLM

 __all__ = [
@@ -20,6 +21,7 @@ __all__ = [
     "BaiChuanForCausalLM",
     "BaichuanForCausalLM",
     "BloomForCausalLM",
+    "ChatGLMForCausalLM",
     "FalconForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",

vllm/model_executor/models/chatglm.py (new file, mode 100644)

# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
    hf_model_weights_iterator,
    load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs import ChatGLMConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]

class GLMAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = ColumnParallelLinear(
            config.hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=config.add_qkv_bias,
            gather_output=False,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            input_is_parallel=True,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim // 2,
            num_kv_heads=self.num_kv_heads,
            is_neox_style=False,
            # is_glm_style=True
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        key_cache, value_cache = kv_cache
        context_layer = self.attn(
            position_ids,
            q,
            k,
            v,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output

class GLMMLP(nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h.
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=config.add_bias_linear,
            gather_output=False,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            input_is_parallel=True,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output

class GLMBlock(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output

class GLMTransformer(nn.Module):
    """Transformer class."""

    def __init__(self, config):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [GLMBlock(config) for i in range(self.num_layers)])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        for i in range(self.num_layers):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                input_metadata=input_metadata,
                cache_event=cache_event,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

class ChatGLMModel(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config)

        self.output_layer = ColumnParallelLinear(
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            gather_output=False,
            params_dtype=config.torch_dtype,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ):
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return hidden_states

class ChatGLMForCausalLM(nn.Module):

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.transformer = ChatGLMModel(config)
        self.lm_head_weight = self.transformer.output_layer.weight
        self.sampler = Sampler(config.padded_vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "output_layer.weight",
        "embedding.weight",
    ]
    _row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"]

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        q_proj_shard_size = self.config.hidden_size // tp_size
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              self.config.multi_query_group_num // tp_size)
        mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size

        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")

            if name in state_dict:
                param = state_dict[name]
                if "query_key_value" in name:
                    q_offset = q_proj_shard_size * tp_rank
                    k_offset = (q_proj_shard_size * tp_size +
                                kv_proj_shard_size * tp_rank)
                    v_offset = (q_proj_shard_size * tp_size +
                                kv_proj_shard_size * (tp_size + tp_rank))
                    wq = loaded_weight[q_offset:q_offset + q_proj_shard_size]
                    wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size]
                    wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size]
                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
                    param.data.copy_(loaded_weight)
                    continue

                if "dense_h_to_4h" in name:
                    w_gate = loaded_weight[mlp_hidden_shard_size *
                                           tp_rank:mlp_hidden_shard_size *
                                           (tp_rank + 1)]
                    w_proj = loaded_weight[mlp_hidden_shard_size *
                                           (tp_size + tp_rank):
                                           mlp_hidden_shard_size *
                                           (tp_size + tp_rank + 1)]
                    loaded_weight = torch.cat([w_gate, w_proj], dim=0)
                    param.data.copy_(loaded_weight)
                    continue

                load_tensor_parallel_weights(
                    param,
                    loaded_weight,
                    name,
                    self._column_parallel_weights,
                    self._row_parallel_weights,
                    tp_rank,
                )
            elif name == "transformer.rotary_pos_emb.inv_freq":
                continue
            else:
                print("Warning never found tensor's name:", name)
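
The fused query_key_value projection in this file packs [Q | K | V] along its output dimension: (total_num_heads + 2 * total_num_kv_heads) * head_dim rows. load_weights therefore slices the full, unsharded checkpoint tensor into this rank's Q, K, and V pieces and concatenates them back in the order the sharded layer expects. A standalone sketch of that offset arithmetic, assuming ChatGLM2-6B-like dimensions (hidden_size=4096, 32 attention heads, 2 KV groups) and a hypothetical tensor-parallel size of 2:

import torch

hidden_size = 4096
num_attention_heads = 32
multi_query_group_num = 2                        # KV groups (grouped-query attention)
head_dim = hidden_size // num_attention_heads    # 128
tp_size = 2

q_proj_shard_size = hidden_size // tp_size                         # 2048 Q rows per rank
kv_proj_shard_size = head_dim * multi_query_group_num // tp_size   # 128 K (and V) rows per rank

# Full checkpoint weight: rows laid out as [Q: 4096 | K: 256 | V: 256].
qkv_weight = torch.randn(
    (num_attention_heads + 2 * multi_query_group_num) * head_dim, hidden_size)

for tp_rank in range(tp_size):
    q_offset = q_proj_shard_size * tp_rank
    k_offset = q_proj_shard_size * tp_size + kv_proj_shard_size * tp_rank
    v_offset = q_proj_shard_size * tp_size + kv_proj_shard_size * (tp_size + tp_rank)
    wq = qkv_weight[q_offset:q_offset + q_proj_shard_size]
    wk = qkv_weight[k_offset:k_offset + kv_proj_shard_size]
    wv = qkv_weight[v_offset:v_offset + kv_proj_shard_size]
    shard = torch.cat([wq, wk, wv], dim=0)
    print(tp_rank, shard.shape)  # each rank gets (2048 + 128 + 128) x 4096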

vllm/transformers_utils/config.py

@@ -5,9 +5,10 @@ from transformers import AutoConfig, MptConfig, PretrainedConfig
 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import

 _CONFIG_REGISTRY = {
-    "mpt": MptConfig,
-    "baichuan": BaiChuanConfig,
     "aquila": AquilaConfig,
+    "baichuan": BaiChuanConfig,
+    "chatglm": ChatGLMConfig,
+    "mpt": MptConfig,
     "qwen": QWenConfig,
     "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
     "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
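
The registry keys on the Hugging Face config's model_type, so checkpoints reporting "chatglm" now resolve to vLLM's own ChatGLMConfig rather than the remote-code config class. A rough sketch of how such a model_type-keyed registry is typically consulted; this is an illustrative paraphrase with a hypothetical helper name, not the get_config function defined elsewhere in this file:

from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs import ChatGLMConfig

_CONFIG_REGISTRY = {"chatglm": ChatGLMConfig}  # trimmed to the new entry


def load_config(model_name_or_path: str) -> PretrainedConfig:
    # Hypothetical helper: first ask transformers what the model_type is,
    # then re-load with vLLM's config class so downstream code sees stable
    # attribute names such as multi_query_group_num.
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    config_class = _CONFIG_REGISTRY.get(config.model_type)
    if config_class is not None:
        config = config_class.from_pretrained(model_name_or_path)
    return config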

vllm/transformers_utils/configs/__init__.py

-from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 from vllm.transformers_utils.configs.aquila import AquilaConfig
+from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
+from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
 from vllm.transformers_utils.configs.qwen import QWenConfig
 # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
...
@@ -8,8 +9,9 @@ from vllm.transformers_utils.configs.falcon import RWConfig
 from vllm.transformers_utils.configs.yi import YiConfig

 __all__ = [
-    "BaiChuanConfig",
     "AquilaConfig",
+    "BaiChuanConfig",
+    "ChatGLMConfig",
     "QWenConfig",
     "RWConfig",
     "YiConfig",

vllm/transformers_utils/configs/chatglm.py (new file, mode 100644)

# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(self,
                 num_layers=28,
                 padded_vocab_size=65024,
                 hidden_size=4096,
                 ffn_hidden_size=13696,
                 kv_channels=128,
                 num_attention_heads=32,
                 seq_length=2048,
                 hidden_dropout=0.0,
                 attention_dropout=0.0,
                 layernorm_epsilon=1e-5,
                 rmsnorm=True,
                 apply_residual_connection_post_layernorm=False,
                 post_layer_norm=True,
                 add_bias_linear=False,
                 add_qkv_bias=False,
                 interleaved_qkv=False,
                 bias_dropout_fusion=True,
                 multi_query_attention=False,
                 multi_query_group_num=1,
                 apply_query_key_layer_scaling=True,
                 attention_softmax_in_fp32=True,
                 fp32_residual_connection=False,
                 quantization_bit=0,
                 pre_seq_len=None,
                 prefix_projection=False,
                 **kwargs):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm)
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)
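
The attribute_map lets code that expects the common Hugging Face names (num_hidden_layers, n_head_kv) read them from a ChatGLMConfig and transparently get num_layers and multi_query_group_num. A quick sanity check using the defaults above; the two constructor overrides are illustrative, not tied to any particular checkpoint:

from vllm.transformers_utils.configs import ChatGLMConfig

config = ChatGLMConfig(multi_query_attention=True, multi_query_group_num=2)
assert config.num_hidden_layers == config.num_layers == 28
assert config.n_head_kv == config.multi_query_group_num == 2
print(config.padded_vocab_size, config.ffn_hidden_size)  # 65024 13696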