vllm · Commits · Commit 64f23c29 (unverified)
Authored Aug 02, 2023 by Song; committed by GitHub on Aug 01, 2023

fix baichuan for different position embedding for 7b and 13b models (#643)

Parent: d4c7755c
Showing 3 changed files with 76 additions and 17 deletions (+76 −17)
Files changed:
  vllm/model_executor/model_loader.py      +2  −1
  vllm/model_executor/models/__init__.py   +2  −1
  vllm/model_executor/models/baichuan.py   +72 −15
vllm/model_executor/model_loader.py

```diff
@@ -11,7 +11,8 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
```
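The two registry keys differ only in the capitalization of "Chuan"/"chuan", which matches the architecture names the two checkpoints report in their config.json. A quick way to see that difference (a sketch, not part of this change; the Hugging Face repo ids are assumptions):

```python
# Sketch: print the architecture name each checkpoint advertises, which is
# presumably the key looked up in _MODEL_REGISTRY. Repo ids are assumptions.
from transformers import AutoConfig

for repo in ("baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"):
    config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    print(repo, config.architectures)

# Per the comments added in the diff, the expected mapping is roughly:
#   Baichuan-7B       -> ['BaiChuanForCausalLM']  (RoPE)
#   Baichuan-13B-Base -> ['BaichuanForCausalLM']  (ALiBi)
```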
vllm/model_executor/models/__init__.py

```diff
-from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
+from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
@@ -10,6 +10,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 __all__ = [
     "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
```
vllm/model_executor/models/baichuan.py

```diff
@@ -22,6 +22,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
+import math
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -31,7 +32,7 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
```
```diff
@@ -44,6 +45,31 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
 class BaiChuanMLP(nn.Module):
 
     def __init__(
```
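The helper above follows the standard ALiBi slope construction. A quick sanity-check sketch, assuming vLLM at this commit is importable (note that `_get_alibi_slopes` is a private helper); 40 is the head count used by Baichuan-13B-style models:

```python
# Sanity-check sketch for the _get_alibi_slopes helper added above.
# Assumes vllm at this commit is on the Python path.
import torch
from vllm.model_executor.models.baichuan import _get_alibi_slopes

slopes = _get_alibi_slopes(40)   # Baichuan-13B uses 40 attention heads
print(slopes.shape)              # torch.Size([40])
print(slopes[:4])                # ~[0.8409, 0.7071, 0.5946, 0.5000]
# For 40 heads, the first 32 slopes form a geometric sequence with ratio
# 2**-0.25, and the remaining 8 come from odd powers of the finer base
# 2**-0.125, matching the usual ALiBi recipe for non-power-of-two head counts.
```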
```diff
@@ -82,6 +108,7 @@ class BaiChuanAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        position_embedding: str,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -92,7 +119,7 @@ class BaiChuanAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.postion_embedding = position_embedding
         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
@@ -109,7 +136,19 @@ class BaiChuanAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
+        # Create the alibi slopes and slice them.
+        if self.postion_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                                scaling, alibi_slopes)
+        else:
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                               self.head_dim,
+                                               self.scaling,
```
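When tensor parallelism is enabled, each rank owns only `num_heads` of the `total_num_heads` heads, so the constructor slices the slope list to the rank's own heads. A standalone illustration of that slicing with assumed numbers (40 total heads, tensor-parallel size 2); it does not import vllm:

```python
# Illustration of the head_start/head_end slicing above with assumed values.
import torch

total_num_heads = 40
tp_size = 2
num_heads = total_num_heads // tp_size      # 20 heads per rank

# Stand-in for _get_alibi_slopes(total_num_heads): one slope value per head.
alibi_slopes = torch.arange(total_num_heads, dtype=torch.float32)

for tp_rank in range(tp_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    shard = alibi_slopes[head_start:head_end]
    print(f"rank {tp_rank}: heads {head_start}..{head_end - 1}, "
          f"{shard.numel()} slopes")
# rank 0: heads 0..19, 20 slopes
# rank 1: heads 20..39, 20 slopes
```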
```diff
@@ -126,20 +165,26 @@ class BaiChuanAttention(nn.Module):
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        if self.postion_embedding == "ALIBI":
+            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                    cache_event)
+        else:
+            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                    input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
 
 
 class BaiChuanDecoderLayer(nn.Module):
 
-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
@@ -181,7 +226,7 @@ class BaiChuanDecoderLayer(nn.Module):
 class BaiChuanModel(nn.Module):
 
-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@ class BaiChuanModel(nn.Module):
             config.hidden_size,
             perform_initialization=False)
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config)
+            BaiChuanDecoderLayer(config, position_embedding)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@ class BaiChuanModel(nn.Module):
         return hidden_states
 
 
-class BaiChuanForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module):
 
-    def __init__(self, config):
+    def __init__(self, config, position_embedding: str):
         super().__init__()
         self.config = config
-        self.model = BaiChuanModel(config)
+        self.model = BaiChuanModel(config, position_embedding)
         self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             config.vocab_size,
                                             bias=False,
@@ -318,3 +363,15 @@ class BaiChuanForCausalLM(nn.Module):
             self._row_parallel_weights,
             tp_rank,
         )
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+
+    def __init__(self, config):
+        super().__init__(config, "ALIBI")
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
+
+    def __init__(self, config):
+        super().__init__(config, "ROPE")
```
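With both subclasses exported and registered, either checkpoint should load through the usual vLLM entry point. A hedged end-to-end sketch (the repo id, the prompt, and the availability of the `trust_remote_code` flag at this commit are assumptions, and a CUDA-capable GPU is required):

```python
# Hedged usage sketch: serve Baichuan-13B, which now resolves to
# BaichuanForCausalLM and therefore uses ALiBi instead of RoPE.
# Repo id and constructor arguments are assumptions, not taken from this diff.
from vllm import LLM, SamplingParams

llm = LLM(model="baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```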