Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
a5e8b048
Unverified
Commit
a5e8b048
authored
Sep 20, 2023
by
Casper
Committed by
GitHub
Sep 20, 2023
Browse files
Merge pull request #60 from casper-hansen/kv_heads
Support kv_heads
parents
bf76e108
a024e893
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
66 deletions
+90
-66
awq/models/falcon.py
awq/models/falcon.py
+4
-1
awq/models/llama.py
awq/models/llama.py
+1
-0
awq/modules/fused/attn.py
awq/modules/fused/attn.py
+53
-22
awq/modules/fused/block.py
awq/modules/fused/block.py
+32
-43
No files found.
awq/models/falcon.py
View file @
a5e8b048
...
@@ -7,7 +7,10 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
...
@@ -7,7 +7,10 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
@
staticmethod
@
staticmethod
def
fuse_layers
(
model
:
FalconForCausalLM
,
quant_config
:
dict
):
def
fuse_layers
(
model
:
FalconForCausalLM
,
quant_config
:
dict
):
fuser
=
FalconFuser
(
model
)
fuser
=
FalconFuser
(
model
)
fuser
.
fuse_transformer
()
# TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
if
model
.
config
.
num_attention_heads
==
71
:
fuser
.
fuse_transformer
()
@
staticmethod
@
staticmethod
def
get_model_layers
(
model
:
FalconForCausalLM
):
def
get_model_layers
(
model
:
FalconForCausalLM
):
...
...
awq/models/llama.py
View file @
a5e8b048
...
@@ -100,6 +100,7 @@ class LlamaFuser:
...
@@ -100,6 +100,7 @@ class LlamaFuser:
attn
=
QuantAttentionFused
(
attn
=
QuantAttentionFused
(
module
.
hidden_size
,
module
.
hidden_size
,
module
.
num_heads
,
module
.
num_heads
,
module
.
num_key_value_heads
,
qkv_layer
,
qkv_layer
,
module
.
o_proj
,
module
.
o_proj
,
next
(
iter
(
qkv_layer
.
state_dict
().
values
())).
device
,
next
(
iter
(
qkv_layer
.
state_dict
().
values
())).
device
,
...
...
awq/modules/fused/attn.py
View file @
a5e8b048
...
@@ -61,34 +61,60 @@ def build_alibi_bias(
...
@@ -61,34 +61,60 @@ def build_alibi_bias(
class
QuantAttentionFused
(
nn
.
Module
):
class
QuantAttentionFused
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
n
um
_heads
,
qkv_layer
,
o_proj
,
dev
,
max_seq_len
,
def
__init__
(
self
,
hidden_size
,
n
_heads
,
n_kv
_heads
,
qkv_layer
,
o_proj
,
dev
,
max_seq_len
,
use_alibi
=
False
,
attention_shapes
=
None
):
use_alibi
=
False
,
attention_shapes
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
n_local_heads
=
num_heads
self
.
n_heads
=
n_heads
self
.
head_dim
=
self
.
hidden_size
//
num_heads
self
.
n_kv_heads
=
n_kv_heads
self
.
n_kv_groups
=
n_heads
//
n_kv_heads
if
n_kv_heads
!=
0
else
0
self
.
head_dim
=
self
.
hidden_size
//
n_heads
self
.
qkv_proj
=
qkv_layer
self
.
qkv_proj
=
qkv_layer
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
start_pos
=
0
self
.
start_pos
=
0
self
.
use_alibi
=
use_alibi
self
.
use_alibi
=
use_alibi
self
.
cache_batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
self
.
cache_batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
self
.
attention_shapes
=
attention_shapes
if
attention_shapes
is
not
None
else
{
# following fastertransformer definition
if
attention_shapes
is
not
None
:
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_local_heads
,
max_seq_len
,
self
.
head_dim
,),
self
.
attention_shapes
=
attention_shapes
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k"
:
(
self
.
cache_batch_size
,
self
.
n_local_heads
,
self
.
head_dim
//
8
,
max_seq_len
,
8
,),
elif
self
.
n_kv_heads
==
0
:
"xqkv_view"
:
(
-
1
,
self
.
n_local_heads
,
self
.
head_dim
),
self
.
attention_shapes
=
{
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
0
],
# following fastertransformer definition
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
1
],
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_heads
,
max_seq_len
,
self
.
head_dim
,),
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
2
],
# 8: pack 8 fp16 in FT, if fp32 then use 4
"xk_reshape"
:
(
self
.
n_local_heads
,
self
.
head_dim
//
8
,
8
),
"cache_k"
:
(
self
.
cache_batch_size
,
self
.
n_heads
,
self
.
head_dim
//
8
,
max_seq_len
,
8
,),
"xq_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"xqkv_view"
:
(
-
1
,
self
.
n_heads
,
self
.
head_dim
),
"xk_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
0
],
"xv_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
1
],
"single_xq_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
2
],
"single_xk_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
),
"xq_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"single_xv_view"
:
(
self
.
n_local_heads
,
self
.
head_dim
)
"xk_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
}
"xv_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"xk_reshape"
:
(
self
.
n_heads
,
self
.
head_dim
//
8
,
8
),
"single_xq_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"single_xk_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"single_xv_view"
:
(
self
.
n_heads
,
self
.
head_dim
)
}
else
:
self
.
attention_shapes
=
{
# following fastertransformer definition
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_kv_heads
,
max_seq_len
,
self
.
head_dim
,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k"
:
(
self
.
cache_batch_size
,
self
.
n_kv_heads
,
self
.
head_dim
//
8
,
max_seq_len
,
8
,),
"xqkv_view"
:
(
self
.
n_heads
+
self
.
n_kv_heads
*
2
,
self
.
head_dim
),
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
0
:
self
.
n_heads
],
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
self
.
n_heads
:
(
self
.
n_heads
+
self
.
n_kv_heads
)],
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
-
self
.
n_kv_heads
:],
"xq_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"xk_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
),
"xv_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
),
"xk_reshape"
:
(
self
.
n_kv_heads
,
self
.
head_dim
//
8
,
8
),
"single_xq_view"
:
(
self
.
n_heads
,
self
.
head_dim
),
"single_xk_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
),
"single_xv_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
)
}
self
.
cache_v
=
(
self
.
cache_v
=
(
torch
.
zeros
(
self
.
attention_shapes
[
"cache_v"
]).
to
(
dev
).
half
()
torch
.
zeros
(
self
.
attention_shapes
[
"cache_v"
]).
to
(
dev
).
half
()
...
@@ -99,14 +125,14 @@ class QuantAttentionFused(nn.Module):
...
@@ -99,14 +125,14 @@ class QuantAttentionFused(nn.Module):
)
)
if
use_alibi
:
if
use_alibi
:
alibi_slopes
,
alibi_bias
=
build_alibi_bias
(
self
.
n_
local_
heads
,
max_seq_len
)
alibi_slopes
,
alibi_bias
=
build_alibi_bias
(
self
.
n_heads
,
max_seq_len
)
self
.
alibi_slopes
=
alibi_slopes
.
float
().
to
(
dev
)
self
.
alibi_slopes
=
alibi_slopes
.
float
().
to
(
dev
)
self
.
alibi_bias
=
alibi_bias
.
float
().
to
(
dev
)
self
.
alibi_bias
=
alibi_bias
.
float
().
to
(
dev
)
self
.
rotary_dim
=
0
self
.
rotary_dim
=
0
self
.
is_neox
=
False
self
.
is_neox
=
False
else
:
else
:
self
.
freqs_cis
=
precompute_freqs_cis
(
self
.
freqs_cis
=
precompute_freqs_cis
(
hidden_size
//
n
um
_heads
,
hidden_size
//
n_heads
,
max_seq_len
*
2
,
max_seq_len
*
2
,
).
to
(
dev
)
).
to
(
dev
)
self
.
rotary_dim
=
self
.
head_dim
self
.
rotary_dim
=
self
.
head_dim
...
@@ -153,6 +179,11 @@ class QuantAttentionFused(nn.Module):
...
@@ -153,6 +179,11 @@ class QuantAttentionFused(nn.Module):
keys
=
xk
keys
=
xk
values
=
xv
values
=
xv
if
self
.
n_kv_groups
!=
0
:
keys
=
torch
.
repeat_interleave
(
keys
,
dim
=
2
,
repeats
=
self
.
n_kv_groups
)
values
=
torch
.
repeat_interleave
(
values
,
dim
=
2
,
repeats
=
self
.
n_kv_groups
)
past_key_value
=
(
xk
,
xv
)
if
use_cache
else
None
past_key_value
=
(
xk
,
xv
)
if
use_cache
else
None
xq
=
xq
.
transpose
(
1
,
2
)
xq
=
xq
.
transpose
(
1
,
2
)
...
...
awq/modules/fused/block.py
View file @
a5e8b048
...
@@ -6,9 +6,13 @@ class MPTBlock(nn.Module):
...
@@ -6,9 +6,13 @@ class MPTBlock(nn.Module):
def
__init__
(
self
,
hidden_size
,
n_heads
,
qkv_layer
,
o_proj
,
mpt_mlp
,
norm_1
,
norm_2
,
dev
,
max_seq_len
):
def
__init__
(
self
,
hidden_size
,
n_heads
,
qkv_layer
,
o_proj
,
mpt_mlp
,
norm_1
,
norm_2
,
dev
,
max_seq_len
):
super
().
__init__
()
super
().
__init__
()
self
.
n_heads
=
n_heads
self
.
n_heads
=
n_heads
self
.
n_kv_heads
=
0
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
norm_1
=
norm_1
self
.
norm_1
=
norm_1
self
.
attn
=
QuantAttentionFused
(
hidden_size
,
self
.
n_heads
,
qkv_layer
,
o_proj
,
dev
=
dev
,
max_seq_len
=
max_seq_len
,
use_alibi
=
True
).
to
(
dev
)
self
.
attn
=
QuantAttentionFused
(
hidden_size
,
self
.
n_heads
,
self
.
n_kv_heads
,
qkv_layer
,
o_proj
,
dev
=
dev
,
max_seq_len
=
max_seq_len
,
use_alibi
=
True
).
to
(
dev
)
self
.
norm_2
=
norm_2
self
.
norm_2
=
norm_2
self
.
ffn
=
mpt_mlp
.
to
(
dev
)
self
.
ffn
=
mpt_mlp
.
to
(
dev
)
...
@@ -30,16 +34,22 @@ class MPTBlock(nn.Module):
...
@@ -30,16 +34,22 @@ class MPTBlock(nn.Module):
return
out
,
None
,
past_key_value
return
out
,
None
,
past_key_value
class
FalconDecoderLayer
(
nn
.
Module
):
class
FalconDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
n_heads
,
qkv_layer
,
o_proj
,
mlp
,
dev
,
max_seq_len
,
input_layernorm
=
None
,
ln_attn
=
None
,
ln_mlp
=
None
,
new_decoder_arch
=
True
):
def
__init__
(
self
,
hidden_size
,
n_heads
,
qkv_layer
,
o_proj
,
mlp
,
dev
,
max_seq_len
,
input_layernorm
=
None
,
ln_attn
=
None
,
ln_mlp
=
None
,
new_decoder_arch
=
True
):
super
().
__init__
()
super
().
__init__
()
self
.
n_heads
=
n_heads
self
.
n_heads
=
n_heads
self
.
n_kv_heads
=
8
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
new_decoder_arch
=
new_decoder_arch
self
.
new_decoder_arch
=
new_decoder_arch
attention_shapes
=
self
.
_get_attention_shapes
(
n_heads
,
max_seq_len
,
self
.
hidden_size
//
n_heads
,
new_decoder_arch
)
if
new_decoder_arch
:
attention_shapes
=
None
else
:
attention_shapes
=
self
.
_get_attention_shapes
(
n_heads
,
max_seq_len
,
self
.
hidden_size
//
n_heads
)
# TODO: Falcon has ALiBi implemented but which model uses it?
# TODO: Falcon has ALiBi implemented but which model uses it?
self
.
attn
=
QuantAttentionFused
(
self
.
attn
=
QuantAttentionFused
(
hidden_size
,
self
.
n_heads
,
qkv_layer
,
o_proj
,
hidden_size
,
self
.
n_heads
,
self
.
n_kv_heads
,
qkv_layer
,
o_proj
,
dev
=
dev
,
max_seq_len
=
max_seq_len
,
use_alibi
=
False
,
dev
=
dev
,
max_seq_len
=
max_seq_len
,
use_alibi
=
False
,
attention_shapes
=
attention_shapes
attention_shapes
=
attention_shapes
).
to
(
dev
)
).
to
(
dev
)
...
@@ -52,47 +62,26 @@ class FalconDecoderLayer(nn.Module):
...
@@ -52,47 +62,26 @@ class FalconDecoderLayer(nn.Module):
self
.
mlp
=
mlp
self
.
mlp
=
mlp
def
_get_attention_shapes
(
self
,
n_heads
,
max_seq_len
,
head_dim
,
new_decoder_arch
):
def
_get_attention_shapes
(
self
,
n_heads
,
max_seq_len
,
head_dim
):
batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
if
new_decoder_arch
:
self
.
attention_shapes
=
{
kv_heads
=
8
# following fastertransformer definition
"cache_v"
:
(
batch_size
,
1
,
max_seq_len
,
head_dim
,),
self
.
attention_shapes
=
{
# 8: pack 8 fp16 in FT, if fp32 then use 4
# following fastertransformer definition
"cache_k"
:
(
batch_size
,
1
,
head_dim
//
8
,
max_seq_len
,
8
,),
"cache_v"
:
(
batch_size
,
n_heads
+
(
kv_heads
*
2
),
max_seq_len
,
head_dim
,),
"xqkv_view"
:
(
n_heads
+
2
,
head_dim
),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:
-
2
],
"cache_k"
:
(
batch_size
,
n_heads
+
(
kv_heads
*
2
),
head_dim
//
8
,
max_seq_len
,
8
,),
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
2
]],
"xqkv_view"
:
(
-
1
,
n_heads
+
(
kv_heads
*
2
),
head_dim
),
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
1
]],
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
0
],
"xq_view"
:
(
n_heads
,
head_dim
),
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
1
],
"xk_view"
:
(
1
,
head_dim
),
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
2
],
"xv_view"
:
(
1
,
head_dim
),
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"xq_view"
:
(
1
,
head_dim
),
"single_xq_view"
:
(
n_heads
,
head_dim
),
"xk_view"
:
(
1
,
head_dim
),
"single_xk_view"
:
(
1
,
head_dim
),
"xv_view"
:
(
1
,
head_dim
),
"single_xv_view"
:
(
1
,
head_dim
)
"single_xq_view"
:
(
n_heads
,
head_dim
),
}
"single_xk_view"
:
(
1
,
8
,
head_dim
),
"single_xv_view"
:
(
1
,
8
,
head_dim
)
}
else
:
self
.
attention_shapes
=
{
# following fastertransformer definition
"cache_v"
:
(
batch_size
,
1
,
max_seq_len
,
head_dim
,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k"
:
(
batch_size
,
1
,
head_dim
//
8
,
max_seq_len
,
8
,),
"xqkv_view"
:
(
n_heads
+
2
,
head_dim
),
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:
-
2
],
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
2
]],
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
1
]],
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"xq_view"
:
(
n_heads
,
head_dim
),
"xk_view"
:
(
1
,
head_dim
),
"xv_view"
:
(
1
,
head_dim
),
"single_xq_view"
:
(
n_heads
,
head_dim
),
"single_xk_view"
:
(
1
,
head_dim
),
"single_xv_view"
:
(
1
,
head_dim
)
}
return
self
.
attention_shapes
return
self
.
attention_shapes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment