Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
20b0d88d
Unverified
Commit
20b0d88d
authored
Jul 17, 2023
by
codethazine
Committed by
GitHub
Jul 17, 2023
Browse files
Add support for baichuan (#365)
parent
2bdea7ac
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
361 additions
and
0 deletions
+361
-0
vllm/model_executor/model_loader.py
vllm/model_executor/model_loader.py
+1
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+2
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+293
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+1
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/baichuan.py
vllm/transformers_utils/configs/baichuan.py
+62
-0
No files found.
vllm/model_executor/model_loader.py
View file @
20b0d88d
...
...
@@ -11,6 +11,7 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY
=
{
"BaiChuanForCausalLM"
:
BaiChuanForCausalLM
,
"BloomForCausalLM"
:
BloomForCausalLM
,
"GPT2LMHeadModel"
:
GPT2LMHeadModel
,
"GPTBigCodeForCausalLM"
:
GPTBigCodeForCausalLM
,
...
...
vllm/model_executor/models/__init__.py
View file @
20b0d88d
from
vllm.model_executor.models.baichuan
import
BaiChuanForCausalLM
from
vllm.model_executor.models.bloom
import
BloomForCausalLM
from
vllm.model_executor.models.gpt2
import
GPT2LMHeadModel
from
vllm.model_executor.models.gpt_bigcode
import
GPTBigCodeForCausalLM
...
...
@@ -8,6 +9,7 @@ from vllm.model_executor.models.mpt import MPTForCausalLM
from
vllm.model_executor.models.opt
import
OPTForCausalLM
__all__
=
[
"BaiChuanForCausalLM"
,
"BloomForCausalLM"
,
"GPT2LMHeadModel"
,
"GPTBigCodeForCausalLM"
,
...
...
vllm/model_executor/models/baichuan.py
0 → 100644
View file @
20b0d88d
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.sequence
import
SequenceOutputs
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.attention
import
PagedAttentionWithRoPE
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.weight_utils
import
(
hf_model_weights_iterator
,
load_tensor_parallel_weights
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.parallel_utils.tensor_parallel
import
(
VocabParallelEmbedding
,
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.transformers_utils.configs.baichuan
import
BaiChuanConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
BaiChuanMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
):
super
().
__init__
()
self
.
gate_up_proj
=
ColumnParallelLinear
(
hidden_size
,
2
*
intermediate_size
,
bias
=
False
,
gather_output
=
False
,
perform_initialization
=
False
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
perform_initialization
=
False
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
BaiChuanAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tensor_model_parallel_world_size
=
get_tensor_model_parallel_world_size
(
)
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tensor_model_parallel_world_size
==
0
self
.
num_heads
=
(
self
.
total_num_heads
//
tensor_model_parallel_world_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
scaling
=
self
.
head_dim
**-
0.5
# pylint: disable=invalid-name
self
.
W_pack
=
ColumnParallelLinear
(
hidden_size
,
3
*
hidden_size
,
bias
=
False
,
gather_output
=
False
,
perform_initialization
=
False
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
perform_initialization
=
False
,
)
self
.
attn
=
PagedAttentionWithRoPE
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
rotary_dim
=
self
.
head_dim
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input_metadata
:
InputMetadata
,
cache_event
:
Optional
[
torch
.
cuda
.
Event
],
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
positions
,
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
,
cache_event
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
BaiChuanDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
BaiChuanConfig
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
BaiChuanAttention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
)
self
.
mlp
=
BaiChuanMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input_metadata
:
InputMetadata
,
cache_event
:
Optional
[
torch
.
cuda
.
Event
],
)
->
torch
.
Tensor
:
# Self Attention
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input_metadata
=
input_metadata
,
cache_event
=
cache_event
,
)
hidden_states
=
residual
+
hidden_states
# Fully Connected
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
class
BaiChuanModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
BaiChuanConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
perform_initialization
=
False
)
self
.
layers
=
nn
.
ModuleList
([
BaiChuanDecoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input_metadata
:
InputMetadata
,
cache_events
:
Optional
[
List
[
torch
.
cuda
.
Event
]],
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
if
cache_events
is
None
:
cache_event
=
None
else
:
cache_event
=
cache_events
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
input_metadata
,
cache_event
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
BaiChuanForCausalLM
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
model
=
BaiChuanModel
(
config
)
self
.
lm_head
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
,
gather_output
=
False
,
perform_initialization
=
False
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input_metadata
:
InputMetadata
,
cache_events
:
Optional
[
List
[
torch
.
cuda
.
Event
]],
)
->
Dict
[
int
,
SequenceOutputs
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input_metadata
,
cache_events
)
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
input_metadata
)
return
next_tokens
_column_parallel_weights
=
[
"embed_tokens.weight"
,
"lm_head.weight"
,
"W_pack.weight"
,
"gate_proj.weight"
,
"up_proj.weight"
]
_row_parallel_weights
=
[
"o_proj.weight"
,
"down_proj.weight"
]
def
load_weights
(
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
use_np_cache
:
bool
=
False
):
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
state_dict
=
self
.
state_dict
()
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
use_np_cache
):
if
"rotary_emb.inv_freq"
in
name
:
continue
is_gate_up_weight
=
False
for
stride_id
,
weight_name
in
enumerate
([
"gate_proj"
,
"up_proj"
]):
if
weight_name
not
in
name
:
continue
param
=
state_dict
[
name
.
replace
(
weight_name
,
"gate_up_proj"
)]
shard_size
=
param
.
shape
[
0
]
//
2
loaded_weight
=
loaded_weight
[
shard_size
*
tensor_model_parallel_rank
:
shard_size
*
(
tensor_model_parallel_rank
+
1
)]
param_slice
=
param
.
data
[
shard_size
*
stride_id
:
shard_size
*
(
stride_id
+
1
)]
assert
param_slice
.
shape
==
loaded_weight
.
shape
param_slice
.
copy_
(
loaded_weight
)
is_gate_up_weight
=
True
break
if
is_gate_up_weight
:
continue
param
=
state_dict
[
name
]
load_tensor_parallel_weights
(
param
,
loaded_weight
,
name
,
self
.
_column_parallel_weights
,
self
.
_row_parallel_weights
,
tensor_model_parallel_rank
)
vllm/transformers_utils/config.py
View file @
20b0d88d
...
...
@@ -4,6 +4,7 @@ from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
_CONFIG_REGISTRY
=
{
"mpt"
:
MPTConfig
,
"baichuan"
:
BaiChuanConfig
,
}
...
...
vllm/transformers_utils/configs/__init__.py
View file @
20b0d88d
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
from
vllm.transformers_utils.configs.baichuan
import
BaiChuanConfig
__all__
=
[
"MPTConfig"
,
"BaiChuanConfig"
,
]
vllm/transformers_utils/configs/baichuan.py
0 → 100644
View file @
20b0d88d
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
transformers.configuration_utils
import
PretrainedConfig
class
BaiChuanConfig
(
PretrainedConfig
):
model_type
=
"baichuan"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
vocab_size
=
64000
,
hidden_size
=
4096
,
intermediate_size
=
11008
,
num_hidden_layers
=
32
,
num_attention_heads
=
32
,
hidden_act
=
"silu"
,
max_position_embeddings
=
4096
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-6
,
use_cache
=
True
,
pad_token_id
=
0
,
bos_token_id
=
1
,
eos_token_id
=
2
,
tie_word_embeddings
=
False
,
**
kwargs
,
):
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
use_cache
=
use_cache
super
().
__init__
(
pad_token_id
=
pad_token_id
,
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
tie_word_embeddings
=
tie_word_embeddings
,
**
kwargs
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment