norm / vllm · Commits

Commit 1f01a18d (unverified)
Authored Apr 02, 2023 by Zhuohan Li; committed by GitHub on Apr 02, 2023

Merge QKV into one linear layer (#15)

Parent: 2c5cd0de
Showing 2 changed files with 90 additions and 83 deletions (+90 -83):

cacheflow/models/llama.py (+52 -47)
cacheflow/models/opt.py (+38 -36)
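The commit replaces the separate q_proj, k_proj, and v_proj column-parallel layers (and, in the Llama MLP, gate_proj and up_proj) with a single fused projection, so each layer issues one matrix multiply instead of three and then splits the result. As a minimal sketch of that pattern, the snippet below uses plain torch.nn.Linear in place of cacheflow's ColumnParallelLinear; the class name and sizes are illustrative, not taken from the repository.

import torch
import torch.nn as nn


class FusedQKVProjection(nn.Module):
    """Illustrative only: one fused QKV matmul instead of three separate ones."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # One (3 * hidden_size, hidden_size) weight replaces three
        # (hidden_size, hidden_size) projections, so a single GEMM
        # produces q, k, and v for the whole batch.
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        qkv = self.qkv_proj(hidden_states)
        # Same reshape/split/squeeze pattern as the diff below.
        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
        q, k, v = torch.split(qkv, 1, dim=-2)
        return (q.squeeze(dim=-2).contiguous(),
                k.squeeze(dim=-2).contiguous(),
                v.squeeze(dim=-2).contiguous())


x = torch.randn(2, 8, 256)                    # (batch, seq_len, hidden_size)
q, k, v = FusedQKVProjection(hidden_size=256)(x)
print(q.shape, k.shape, v.shape)              # each torch.Size([2, 8, 256])

The fusion is purely an efficiency change: as long as the original weights are stacked in the same q/k/v order that the split expects, the fused layer computes exactly the same outputs as three separate layers.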
cacheflow/models/llama.py
@@ -33,22 +33,21 @@ class LlamaMLP(nn.Module):
         hidden_act: str,
     ):
         super().__init__()
-        # TODO: Merge the gate and down linear layers.
-        self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
-                                              bias=False, gather_output=False,
-                                              perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
+                                                 bias=False, gather_output=False,
+                                                 perform_initialization=False)
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
-                                            bias=False, gather_output=False,
-                                            perform_initialization=False)
         assert hidden_act == 'silu'
         self.act_fn = nn.SiLU()

     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        up, _ = self.up_proj(x)
+        gate_up, _ = self.gate_up_proj(x)
+        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
+        gate, up = torch.split(gate_up, 1, dim=-2)
+        gate = gate.squeeze(dim=-2).contiguous()
+        up = up.squeeze(dim=-2).contiguous()
         x = self.act_fn(gate) * up
         x, _ = self.down_proj(x)
         return x
@@ -70,24 +69,9 @@ class LlamaAttention(nn.Module):
         self.head_dim = hidden_size // self.total_num_heads
         self.scaling = self.head_dim**-0.5

-        # TODO: Merge the QKV linear layers.
-        self.q_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            perform_initialization=False,
-        )
-        self.k_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            perform_initialization=False,
-        )
-        self.v_proj = ColumnParallelLinear(
+        self.qkv_proj = ColumnParallelLinear(
             hidden_size,
-            self.total_num_heads * self.head_dim,
+            3 * self.total_num_heads * self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
@@ -109,9 +93,12 @@ class LlamaAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
-        k, _ = self.k_proj(hidden_states)
-        v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
+        q, k, v = torch.split(qkv, 1, dim=-2)
+        q = q.squeeze(dim=-2).contiguous()
+        k = k.squeeze(dim=-2).contiguous()
+        v = v.squeeze(dim=-2).contiguous()
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
@@ -230,8 +217,7 @@ class LlamaForCausalLM(nn.Module):
         return next_tokens

     _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
-                                "q_proj.weight", "k_proj.weight",
-                                "v_proj.weight", "gate_proj.weight",
+                                "qkv_proj.weight", "gate_proj.weight",
                                 "up_proj.weight"]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
@@ -239,23 +225,42 @@ class LlamaForCausalLM(nn.Module):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, param in state_dict.items():
-            loaded_weight = torch.from_numpy(
-                np.load(os.path.join(weights_path, name)))
-            for p in self._column_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[0]
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
-            for p in self._row_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[1]
-                    loaded_weight = loaded_weight[
-                        :,
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
+            if "qkv_proj" in name or "gate_up_proj" in name:
+                if "qkv_proj" in name:
+                    original_name = "qkv_proj"
+                    weight_names = ["q_proj", "k_proj", "v_proj"]
+                    shard_size = param.shape[0] // 3
+                else:
+                    original_name = "gate_up_proj"
+                    weight_names = ["gate_proj", "up_proj"]
+                    shard_size = param.shape[0] // 2
+                weights_to_concat = []
+                for weight_name in weight_names:
+                    weight = np.load(os.path.join(
+                        weights_path, name.replace(original_name, weight_name)))
+                    weights_to_concat.append(weight[
+                        shard_size * tensor_model_parallel_rank
+                        :shard_size * (tensor_model_parallel_rank + 1)])
+                loaded_weight = torch.from_numpy(
+                    np.concatenate(weights_to_concat, axis=0))
+            else:
+                loaded_weight = torch.from_numpy(
+                    np.load(os.path.join(weights_path, name)))
+                for p in self._column_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[0]
+                        loaded_weight = loaded_weight[
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
+                for p in self._row_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[1]
+                        loaded_weight = loaded_weight[
+                            :,
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break

             assert param.shape == loaded_weight.shape
             param.data.copy_(loaded_weight)
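Because the checkpoint on disk still stores q_proj, k_proj, and v_proj (and gate_proj/up_proj) separately, load_weights above re-packs them: for each fused parameter it loads the individual arrays, slices out this rank's tensor-parallel shard of each, and concatenates the slices along the output dimension. Below is a minimal NumPy sketch of that re-packing step using hypothetical sizes and rank values (none of these names come from the repository):

import numpy as np

# Hypothetical sizes for illustration only.
hidden_size = 8
tp_world_size = 2
tp_rank = 1                      # this worker's tensor-parallel rank

# Stand-ins for arrays loaded from q_proj.weight / k_proj.weight / v_proj.weight.
q_w = np.arange(hidden_size * hidden_size, dtype=np.float32).reshape(hidden_size, hidden_size)
k_w = q_w + 100.0
v_w = q_w + 200.0

# Each rank owns (hidden_size // tp_world_size) output rows of every projection.
shard_size = hidden_size // tp_world_size
weights_to_concat = [
    w[shard_size * tp_rank: shard_size * (tp_rank + 1)]
    for w in (q_w, k_w, v_w)
]

# The local qkv_proj shard is the per-projection shards stacked in q, k, v order,
# so its leading dimension is 3 * shard_size -- which is why the loader computes
# shard_size as param.shape[0] // 3.
qkv_shard = np.concatenate(weights_to_concat, axis=0)
print(qkv_shard.shape)           # (12, 8)

The assert at the end of the loop then checks that the concatenated shard has exactly the shape of the fused parameter before copying it in.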
cacheflow/models/opt.py
@@ -53,16 +53,9 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim**-0.5

-        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
-        self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
+                                             gather_output=False,
+                                             perform_initialization=False)
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
@@ -75,16 +68,18 @@ class OPTAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
-        k, _ = self.k_proj(hidden_states)
-        v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
+        q, k, v = torch.split(qkv, 1, dim=-2)
+        q = q.squeeze(dim=-2).contiguous()
+        k = k.squeeze(dim=-2).contiguous()
+        v = v.squeeze(dim=-2).contiguous()
         key_cache, value_cache = kv_cache
         attn_output = self.attn(
             q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output


 class OPTDecoderLayer(nn.Module):

     def __init__(self, config: OPTConfig):
@@ -262,11 +257,7 @@ class OPTForCausalLM(nn.Module):
             self.lm_head_weight, hidden_states, input_metadata)
         return next_tokens

-    _column_parallel_weights = ["embed_tokens.weight", "q_proj.weight",
-                                "k_proj.weight", "v_proj.weight",
-                                "fc1.weight"]
-    _column_parallel_biases = ["q_proj.bias", "k_proj.bias",
-                               "v_proj.bias", "fc1.bias"]
+    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
     _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

     def load_weights(self, weights_path: str):
@@ -275,24 +266,35 @@ class OPTForCausalLM(nn.Module):
         for name, param in state_dict.items():
             if "lm_head_weight" in name:
                 continue
-            loaded_weight = torch.from_numpy(
-                np.load(os.path.join(weights_path, name)))
-            for p in (self._column_parallel_weights
-                      + self._column_parallel_biases):
-                if p in name:
-                    shard_size = param.shape[0]
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
-            for p in self._row_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[1]
-                    loaded_weight = loaded_weight[
-                        :,
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
+            if "qkv_proj" in name:
+                shard_size = param.shape[0] // 3
+                weights_to_concat = []
+                for weight_name in ["q_proj", "k_proj", "v_proj"]:
+                    weight = np.load(os.path.join(
+                        weights_path, name.replace("qkv_proj", weight_name)))
+                    weights_to_concat.append(weight[
+                        shard_size * tensor_model_parallel_rank
+                        :shard_size * (tensor_model_parallel_rank + 1)])
+                loaded_weight = torch.from_numpy(
+                    np.concatenate(weights_to_concat, axis=0))
+            else:
+                loaded_weight = torch.from_numpy(
+                    np.load(os.path.join(weights_path, name)))
+                for p in self._column_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[0]
+                        loaded_weight = loaded_weight[
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
+                for p in self._row_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[1]
+                        loaded_weight = loaded_weight[
+                            :,
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break

             assert param.shape == loaded_weight.shape
             param.data.copy_(loaded_weight)