Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
3256ffec
Commit
3256ffec
authored
Sep 20, 2023
by
Casper Hansen
Browse files
Support kv_heads
parent
63a12504
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
27 deletions
+61
-27
awq/models/llama.py
awq/models/llama.py
+1
-0
awq/modules/fused/attn.py
awq/modules/fused/attn.py
+49
-22
awq/modules/fused/block.py
awq/modules/fused/block.py
+11
-5
No files found.
awq/models/llama.py
View file @
3256ffec
...
...
@@ -100,6 +100,7 @@ class LlamaFuser:
attn
=
QuantAttentionFused
(
module
.
hidden_size
,
module
.
num_heads
,
module
.
num_key_value_heads
,
qkv_layer
,
module
.
o_proj
,
next
(
iter
(
qkv_layer
.
state_dict
().
values
())).
device
,
...
...
awq/modules/fused/attn.py
View file @
3256ffec
...
...
@@ -115,34 +115,61 @@ class QuantLlamaRotaryEmbedding(nn.Module):
return
query
,
key
class
QuantAttentionFused
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
n
um
_heads
,
qkv_layer
,
o_proj
,
dev
,
max_seq_len
,
def
__init__
(
self
,
hidden_size
,
n
_heads
,
n_kv
_heads
,
qkv_layer
,
o_proj
,
dev
,
max_seq_len
,
use_alibi
=
False
,
attention_shapes
=
None
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
n_local_heads
=
num_heads
self
.
head_dim
=
self
.
hidden_size
//
num_heads
self
.
n_heads
=
n_heads
self
.
n_kv_heads
=
n_kv_heads
self
.
head_dim
=
self
.
hidden_size
//
n_heads
self
.
qkv_proj
=
qkv_layer
self
.
o_proj
=
o_proj
self
.
start_pos
=
0
self
.
use_alibi
=
use_alibi
self
.
cache_batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
self
.
attention_shapes
=
attention_shapes
if
attention_shapes
is
not
None
else
{
# following fastertransformer definition
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_local_heads
,
max_seq_len
,
self
.
head_dim
,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k"
:
(
self
.
cache_batch_size
,
self
.
n_local_heads
,
self
.
head_dim
//
8
,
max_seq_len
,
8
,),
"xqkv_view"
:
(
-
1
,
self
.
n_local_heads
,
self
.
head_dim
),
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
0
],
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
1
],
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
2
],
"xk_reshape"
:
(
self
.
n_local_heads
,
self
.
head_dim
//
8
,
8
),
"xq_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"xk_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"xv_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"single_xq_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"single_xk_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"single_xv_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
)
}
if
attention_shapes
is
not
None
:
self
.
attention_shapes
=
attention_shapes
elif
self
.
n_kv_heads
==
0
:
self
.
attention_shapes
=
{
# following fastertransformer definition
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_heads
,
max_seq_len
,
self
.
head_dim
,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k"
:
(
self
.
cache_batch_size
,
self
.
n_heads
,
self
.
head_dim
//
8
,
max_seq_len
,
8
,),
"xqkv_view"
:
(
-
1
,
self
.
n_heads
,
self
.
head_dim
),
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
0
],
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
1
],
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
2
],
"xq_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"xk_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"xv_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"xk_reshape"
:
(
self
.
n_heads
,
self
.
head_dim
//
8
,
8
),
"single_xq_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"single_xk_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"single_xv_view"
:
(
self
.
n_heads
,
self
.
head_dim
)
}
else
:
self
.
attention_shapes
=
{
# following fastertransformer definition
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_kv_heads
,
max_seq_len
,
self
.
head_dim
,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k"
:
(
self
.
cache_batch_size
,
self
.
n_kv_heads
,
self
.
head_dim
//
8
,
max_seq_len
,
8
,),
"xqkv_view"
:
(
self
.
n_heads
+
self
.
n_kv_heads
*
2
,
self
.
head_dim
),
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
0
:
self
.
n_kv_heads
],
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
self
.
n_heads
:
(
self
.
n_heads
+
self
.
n_kv_heads
)],
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
-
self
.
n_kv_heads
:],
"xq_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
),
"xk_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
),
"xv_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
),
"xk_reshape"
:
(
self
.
n_kv_heads
,
self
.
head_dim
//
8
,
8
),
"single_xq_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
),
"single_xk_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
),
"single_xv_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
)
}
print
(
self
.
attention_shapes
)
self
.
cache_v
=
(
torch
.
zeros
(
self
.
attention_shapes
[
"cache_v"
]).
to
(
dev
).
half
()
...
...
@@ -153,14 +180,14 @@ class QuantAttentionFused(nn.Module):
)
if
use_alibi
:
alibi_slopes
,
alibi_bias
=
build_alibi_bias
(
self
.
n_
local_
heads
,
max_seq_len
)
alibi_slopes
,
alibi_bias
=
build_alibi_bias
(
self
.
n_heads
,
max_seq_len
)
self
.
alibi_slopes
=
alibi_slopes
.
float
().
to
(
dev
)
self
.
alibi_bias
=
alibi_bias
.
float
().
to
(
dev
)
self
.
rotary_dim
=
0
self
.
is_neox
=
False
else
:
self
.
freqs_cis
=
precompute_freqs_cis
(
hidden_size
//
n
um
_heads
,
hidden_size
//
n_heads
,
max_seq_len
*
2
,
).
to
(
dev
)
self
.
rotary_dim
=
self
.
head_dim
...
...
awq/modules/fused/block.py
View file @
3256ffec
...
...
@@ -6,9 +6,13 @@ class MPTBlock(nn.Module):
def
__init__
(
self
,
hidden_size
,
n_heads
,
qkv_layer
,
o_proj
,
mpt_mlp
,
norm_1
,
norm_2
,
dev
,
max_seq_len
):
super
().
__init__
()
self
.
n_heads
=
n_heads
self
.
n_kv_heads
=
0
self
.
hidden_size
=
hidden_size
self
.
norm_1
=
norm_1
self
.
attn
=
QuantAttentionFused
(
hidden_size
,
self
.
n_heads
,
qkv_layer
,
o_proj
,
dev
=
dev
,
max_seq_len
=
max_seq_len
,
use_alibi
=
True
).
to
(
dev
)
self
.
attn
=
QuantAttentionFused
(
hidden_size
,
self
.
n_heads
,
self
.
n_kv_heads
,
qkv_layer
,
o_proj
,
dev
=
dev
,
max_seq_len
=
max_seq_len
,
use_alibi
=
True
).
to
(
dev
)
self
.
norm_2
=
norm_2
self
.
ffn
=
mpt_mlp
.
to
(
dev
)
...
...
@@ -30,16 +34,18 @@ class MPTBlock(nn.Module):
return
out
,
None
,
past_key_value
class
FalconDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
n_heads
,
qkv_layer
,
o_proj
,
mlp
,
dev
,
max_seq_len
,
input_layernorm
=
None
,
ln_attn
=
None
,
ln_mlp
=
None
,
new_decoder_arch
=
True
):
def
__init__
(
self
,
hidden_size
,
n_heads
,
qkv_layer
,
o_proj
,
mlp
,
dev
,
max_seq_len
,
input_layernorm
=
None
,
ln_attn
=
None
,
ln_mlp
=
None
,
new_decoder_arch
=
True
):
super
().
__init__
()
self
.
n_heads
=
n_heads
self
.
n_kv_heads
=
8
self
.
hidden_size
=
hidden_size
self
.
new_decoder_arch
=
new_decoder_arch
attention_shapes
=
self
.
_get_attention_shapes
(
n_heads
,
max_seq_len
,
self
.
hidden_size
//
n_heads
,
new_decoder_arch
)
# TODO: Falcon has ALiBi implemented but which model uses it?
self
.
attn
=
QuantAttentionFused
(
hidden_size
,
self
.
n_heads
,
qkv_layer
,
o_proj
,
hidden_size
,
self
.
n_heads
,
self
.
n_kv_heads
,
qkv_layer
,
o_proj
,
dev
=
dev
,
max_seq_len
=
max_seq_len
,
use_alibi
=
False
,
attention_shapes
=
attention_shapes
).
to
(
dev
)
...
...
@@ -67,10 +73,10 @@ class FalconDecoderLayer(nn.Module):
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
0
],
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
1
],
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
2
],
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"xq_view"
:
(
1
,
head_dim
),
"xk_view"
:
(
1
,
head_dim
),
"xv_view"
:
(
1
,
head_dim
),
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"single_xq_view"
:
(
n_heads
,
head_dim
),
"single_xk_view"
:
(
1
,
8
,
head_dim
),
"single_xv_view"
:
(
1
,
8
,
head_dim
)
...
...
@@ -85,10 +91,10 @@ class FalconDecoderLayer(nn.Module):
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:
-
2
],
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
2
]],
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
1
]],
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"xq_view"
:
(
n_heads
,
head_dim
),
"xk_view"
:
(
1
,
head_dim
),
"xv_view"
:
(
1
,
head_dim
),
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"single_xq_view"
:
(
n_heads
,
head_dim
),
"single_xk_view"
:
(
1
,
head_dim
),
"single_xv_view"
:
(
1
,
head_dim
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment