OpenDAS / AutoAWQ · Commits

Commit 7f8f9f16
authored Sep 12, 2023 by Casper Hansen

xk_reshape key

parent a8c9afd5
Showing 1 changed file with 5 additions and 4 deletions.

awq/modules/fused/attn.py  (+5 -4)
@@ -114,7 +114,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         return query, key
 
 class QuantAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False):
+    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False, attention_shapes=None):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_local_heads = num_heads
@@ -124,7 +124,7 @@ class QuantAttentionFused(nn.Module):
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = 1
-        self.attention_shapes = {
+        self.attention_shapes = attention_shapes if attention_shapes is not None else {
             # following fastertransformer definition
             "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
             # 8: pack 8 fp16 in FT, if fp32 then use 4
@@ -133,13 +133,14 @@ class QuantAttentionFused(nn.Module):
             "xq_slice": lambda xqkv: xqkv[:, :, 0],
             "xk_slice": lambda xqkv: xqkv[:, :, 1],
             "xv_slice": lambda xqkv: xqkv[:, :, 2],
+            "xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
             "xk_view": (self.n_local_heads, self.head_dim),
             "xv_view": (self.n_local_heads, self.head_dim),
             "single_xq_view": (self.n_local_heads, self.head_dim),
             "single_xk_view": (self.n_local_heads, self.head_dim),
             "single_xv_view": (self.n_local_heads, self.head_dim)
         }
         self.cache_v = (
             torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
         )
@@ -187,7 +188,7 @@ class QuantAttentionFused(nn.Module):
         values_store = xv.transpose(2, 1)
         keys_store = (
-            xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8)
+            xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
             .permute(0, 2, 3, 1, 4)
             .contiguous()
         )
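The purpose of the new "xk_reshape" key shows up in the last hunk: the FasterTransformer-style key cache packs 8 fp16 values along the trailing axis, so a (bsz, seqlen, n_local_heads, head_dim) key tensor is reshaped to (bsz, seqlen, n_local_heads, head_dim // 8, 8) and permuted to (bsz, n_local_heads, head_dim // 8, seqlen, 8). A minimal standalone sketch checking that the new tuple-driven reshape matches the previously hard-coded one (plain PyTorch with made-up sizes; nothing here depends on the module itself):

import torch

# Illustrative sizes only; the real values come from the model config.
bsz, seqlen, n_local_heads, head_dim = 1, 16, 32, 128

xk = torch.randn(bsz, seqlen, n_local_heads, head_dim, dtype=torch.half)

# Value taken from the attention_shapes entry added in this commit.
xk_reshape = (n_local_heads, head_dim // 8, 8)

# Old, hard-coded form of keys_store:
old = (
    xk.reshape(bsz, seqlen, n_local_heads, head_dim // 8, 8)
    .permute(0, 2, 3, 1, 4)
    .contiguous()
)

# New form: the target shape is read from the shapes dict instead.
new = (
    xk.reshape((bsz, seqlen) + xk_reshape)
    .permute(0, 2, 3, 1, 4)
    .contiguous()
)

assert new.shape == (bsz, n_local_heads, head_dim // 8, seqlen, 8)
assert torch.equal(old, new)

Both forms produce identical tensors; the change only moves the target shape out of the call site so it can vary per model.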
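Because attention_shapes defaults to None, existing callers keep the hard-coded FasterTransformer layout; passing a dict replaces it wholesale. A hedged sketch of what an override might look like, rebuilt only from the keys visible in this diff (the sizes are invented, and the real default dict may contain further keys that these hunks do not show, e.g. for the key cache, so a complete override would need those as well):

from awq.modules.fused.attn import QuantAttentionFused

# Illustrative configuration; real values come from the quantized model.
hidden_size, num_heads, max_seq_len = 4096, 32, 2048
head_dim = hidden_size // num_heads
cache_batch_size = 1

custom_shapes = {
    # Only the keys visible in this commit's hunks are reproduced here.
    "cache_v": (cache_batch_size, num_heads, max_seq_len, head_dim,),
    "xq_slice": lambda xqkv: xqkv[:, :, 0],
    "xk_slice": lambda xqkv: xqkv[:, :, 1],
    "xv_slice": lambda xqkv: xqkv[:, :, 2],
    "xk_reshape": (num_heads, head_dim // 8, 8),
    "xk_view": (num_heads, head_dim),
    "xv_view": (num_heads, head_dim),
    "single_xq_view": (num_heads, head_dim),
    "single_xk_view": (num_heads, head_dim),
    "single_xv_view": (num_heads, head_dim),
}

# The quantized qkv_layer and o_proj come from the loaded model, so the
# constructor call is only sketched here:
# attn = QuantAttentionFused(hidden_size, num_heads, qkv_layer, o_proj,
#                            dev="cuda:0", max_seq_len=max_seq_len,
#                            use_alibi=False, attention_shapes=custom_shapes)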