Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
d973a0e0
Commit
d973a0e0
authored
Sep 20, 2023
by
Casper Hansen
Browse files
Use default attention shapes
parent
2d593b84
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
40 deletions
+23
-40
awq/modules/fused/block.py
awq/modules/fused/block.py
+23
-40
No files found.
awq/modules/fused/block.py
View file @
d973a0e0
...
@@ -41,7 +41,11 @@ class FalconDecoderLayer(nn.Module):
...
@@ -41,7 +41,11 @@ class FalconDecoderLayer(nn.Module):
self
.
n_kv_heads
=
8
self
.
n_kv_heads
=
8
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
new_decoder_arch
=
new_decoder_arch
self
.
new_decoder_arch
=
new_decoder_arch
attention_shapes
=
self
.
_get_attention_shapes
(
n_heads
,
max_seq_len
,
self
.
hidden_size
//
n_heads
,
new_decoder_arch
)
if
new_decoder_arch
:
attention_shapes
=
None
else
:
attention_shapes
=
self
.
_get_attention_shapes
(
n_heads
,
max_seq_len
,
self
.
hidden_size
//
n_heads
)
# TODO: Falcon has ALiBi implemented but which model uses it?
# TODO: Falcon has ALiBi implemented but which model uses it?
self
.
attn
=
QuantAttentionFused
(
self
.
attn
=
QuantAttentionFused
(
...
@@ -58,47 +62,26 @@ class FalconDecoderLayer(nn.Module):
...
@@ -58,47 +62,26 @@ class FalconDecoderLayer(nn.Module):
self
.
mlp
=
mlp
self
.
mlp
=
mlp
def
_get_attention_shapes
(
self
,
n_heads
,
max_seq_len
,
head_dim
,
new_decoder_arch
):
def
_get_attention_shapes
(
self
,
n_heads
,
max_seq_len
,
head_dim
):
batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
if
new_decoder_arch
:
self
.
attention_shapes
=
{
kv_heads
=
8
# following fastertransformer definition
"cache_v"
:
(
batch_size
,
1
,
max_seq_len
,
head_dim
,),
self
.
attention_shapes
=
{
# 8: pack 8 fp16 in FT, if fp32 then use 4
# following fastertransformer definition
"cache_k"
:
(
batch_size
,
1
,
head_dim
//
8
,
max_seq_len
,
8
,),
"cache_v"
:
(
batch_size
,
n_heads
+
(
kv_heads
*
2
),
max_seq_len
,
head_dim
,),
"xqkv_view"
:
(
n_heads
+
2
,
head_dim
),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:
-
2
],
"cache_k"
:
(
batch_size
,
n_heads
+
(
kv_heads
*
2
),
head_dim
//
8
,
max_seq_len
,
8
,),
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
2
]],
"xqkv_view"
:
(
-
1
,
n_heads
+
(
kv_heads
*
2
),
head_dim
),
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
1
]],
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
0
],
"xq_view"
:
(
n_heads
,
head_dim
),
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
1
],
"xk_view"
:
(
1
,
head_dim
),
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:,
2
],
"xv_view"
:
(
1
,
head_dim
),
"xq_view"
:
(
1
,
head_dim
),
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"xk_view"
:
(
1
,
head_dim
),
"single_xq_view"
:
(
n_heads
,
head_dim
),
"xv_view"
:
(
1
,
head_dim
),
"single_xk_view"
:
(
1
,
head_dim
),
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"single_xv_view"
:
(
1
,
head_dim
)
"single_xq_view"
:
(
n_heads
,
head_dim
),
}
"single_xk_view"
:
(
1
,
8
,
head_dim
),
"single_xv_view"
:
(
1
,
8
,
head_dim
)
}
else
:
self
.
attention_shapes
=
{
# following fastertransformer definition
"cache_v"
:
(
batch_size
,
1
,
max_seq_len
,
head_dim
,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k"
:
(
batch_size
,
1
,
head_dim
//
8
,
max_seq_len
,
8
,),
"xqkv_view"
:
(
n_heads
+
2
,
head_dim
),
"xq_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
:
-
2
],
"xk_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
2
]],
"xv_slice"
:
lambda
xqkv
:
xqkv
[:,
:,
[
-
1
]],
"xq_view"
:
(
n_heads
,
head_dim
),
"xk_view"
:
(
1
,
head_dim
),
"xv_view"
:
(
1
,
head_dim
),
"xk_reshape"
:
(
1
,
head_dim
//
8
,
8
),
"single_xq_view"
:
(
n_heads
,
head_dim
),
"single_xk_view"
:
(
1
,
head_dim
),
"single_xv_view"
:
(
1
,
head_dim
)
}
return
self
.
attention_shapes
return
self
.
attention_shapes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment