vllm commit 64f23c29

fix baichuan for different position embedding for 7b and 13b models (#643)

Authored Aug 02, 2023 by Song; committed by GitHub Aug 01, 2023
Parent: d4c7755c
Showing 3 changed files with 76 additions and 17 deletions:

  vllm/model_executor/model_loader.py       +2  -1
  vllm/model_executor/models/__init__.py    +2  -1
  vllm/model_executor/models/baichuan.py    +72 -15
vllm/model_executor/model_loader.py

@@ -11,7 +11,8 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights

 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
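The registry key is matched against the `architectures` field of the checkpoint's Hugging Face config, which is why both spellings are needed: the 7B checkpoints advertise "BaiChuanForCausalLM" (capital C) while the 13B checkpoints advertise "BaichuanForCausalLM" (lowercase c). A minimal sketch of that lookup, using a hypothetical `resolve_model_class` helper and stub classes rather than vLLM's actual loader code:

from typing import List, Type

class BaiChuanForCausalLM:        # stub standing in for the RoPE-based 7B model class
    pass

class BaichuanForCausalLM:        # stub standing in for the ALiBi-based 13B model class
    pass

_MODEL_REGISTRY = {
    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
}

def resolve_model_class(architectures: List[str]) -> Type:
    # Return the first registered class named in the config's `architectures` list.
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Unsupported architectures: {architectures}")

print(resolve_model_class(["BaiChuanForCausalLM"]).__name__)   # 7B  -> RoPE class
print(resolve_model_class(["BaichuanForCausalLM"]).__name__)   # 13B -> ALiBi class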
vllm/model_executor/models/__init__.py

-from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
+from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM

@@ -10,6 +10,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 __all__ = [
     "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
vllm/model_executor/models/baichuan.py

@@ -22,6 +22,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
+import math
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -31,7 +32,7 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
@@ -44,6 +45,31 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 KVCache = Tuple[torch.Tensor, torch.Tensor]


+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
 class BaiChuanMLP(nn.Module):

     def __init__(
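This is the standard ALiBi slope construction: geometric slopes for the nearest power-of-two head count, plus interpolated slopes for any remaining heads. A quick self-contained check, assuming Baichuan-13B's 40 attention heads so that both branches run; the helper is reproduced here rather than imported, since `_get_alibi_slopes` is a private function in the module above.

import math
import torch

def get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    # Same construction as _get_alibi_slopes in the diff above.
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2,
                                    dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes

slopes = get_alibi_slopes(40)   # Baichuan-13B uses 40 attention heads (not a power of 2)
print(slopes.shape)             # torch.Size([40])
print(slopes[0].item())         # 2**-0.25  ~= 0.8409, first of the 32 "regular" slopes
print(slopes[31].item())        # 2**-8      = 0.00390625, last regular slope
print(slopes[32].item())        # 2**-0.125 ~= 0.9170, first of the 8 interpolated slopes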
@@ -82,6 +108,7 @@ class BaiChuanAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        position_embedding: str,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -92,7 +119,7 @@ class BaiChuanAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.postion_embedding = position_embedding

         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
@@ -109,11 +136,23 @@ class BaiChuanAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           rotary_dim=self.head_dim)
+        # Create the alibi slopes and slice them.
+        if self.postion_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                                scaling, alibi_slopes)
+        else:
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                               self.head_dim,
+                                               self.scaling,
+                                               rotary_dim=self.head_dim)

     def forward(
         self,
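The constructor computes slopes for all heads and then keeps only the slice owned by the local tensor-parallel rank. A small sketch of that slicing, assuming 40 total heads and a tensor-parallel world size of 2, and reusing `get_alibi_slopes` from the previous snippet (the world size and rank loop here are illustrative, not taken from a running vLLM instance):

total_num_heads = 40
tensor_model_parallel_world_size = 2
num_heads = total_num_heads // tensor_model_parallel_world_size  # 20 heads per rank

slopes = get_alibi_slopes(total_num_heads)
for tp_rank in range(tensor_model_parallel_world_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    local_slopes = slopes[head_start:head_end].tolist()
    print(tp_rank, len(local_slopes))   # each rank keeps 20 of the 40 slopes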
@@ -126,20 +165,26 @@ class BaiChuanAttention(nn.Module):
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        if self.postion_embedding == "ALIBI":
+            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                    cache_event)
+        else:
+            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                    input_metadata, cache_event)
+
         output, _ = self.o_proj(attn_output)
         return output


 class BaiChuanDecoderLayer(nn.Module):

-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
@@ -181,7 +226,7 @@ class BaiChuanDecoderLayer(nn.Module):

 class BaiChuanModel(nn.Module):

-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@ class BaiChuanModel(nn.Module):
                                   config.hidden_size,
                                   perform_initialization=False)
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config)
+            BaiChuanDecoderLayer(config, position_embedding)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@ class BaiChuanModel(nn.Module):
         return hidden_states


-class BaiChuanForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config, position_embedding: str):
         super().__init__()
         self.config = config
-        self.model = BaiChuanModel(config)
+        self.model = BaiChuanModel(config, position_embedding)
         self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             config.vocab_size,
                                             bias=False,
@@ -318,3 +363,15 @@ class BaiChuanForCausalLM(nn.Module):
             self._row_parallel_weights,
             tp_rank,
         )
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+
+    def __init__(self, config):
+        super().__init__(config, "ALIBI")
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
+
+    def __init__(self, config):
+        super().__init__(config, "ROPE")
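With these two thin subclasses, the position-embedding choice is made entirely by which class the registry resolves from the checkpoint's config. A hedged usage sketch with vLLM's `LLM` entry point; the Hugging Face repo ID is only an example, and `trust_remote_code` is assumed to be needed because Baichuan ships a custom config class:

from vllm import LLM, SamplingParams

# "baichuan-inc/Baichuan-13B-Chat" advertises "BaichuanForCausalLM" in its
# config.json, so it resolves to the new ALiBi-based class. A 7B checkpoint
# ("BaiChuanForCausalLM" in config.json) would resolve to the RoPE-based class.
llm = LLM(model="baichuan-inc/Baichuan-13B-Chat", trust_remote_code=True)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)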