OpenDAS / AutoAWQ / Commits / d7badefc

Commit d7badefc, authored Sep 11, 2023 by Casper Hansen

Switch to torch SDPA

parent 54f02854

Showing 1 changed file with 28 additions and 18 deletions:

awq/modules/fused/attn.py (+28, -18)
@@ -154,6 +154,28 @@ class QuantAttentionFused(nn.Module):
             self.alibi_slopes = None
             self.is_neox = True
 
+    def _multi_query_attention_torch(self, query, key, value, batch_size, seqlen, use_cache, past_key_value, attention_mask):
+        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
+
+        if use_cache:
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()
+
+        output = F.scaled_dot_product_attention(
+            query, key, value, is_causal=past_key_value is None, attn_mask=attention_mask
+        )
+        del query, key, value
+
+        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)
+
+        return output
+
     def forward(
         self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
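The helper added above only reshapes the flat (batch, seqlen, hidden) projections into per-head layout, hands them to torch's fused scaled_dot_product_attention, and flattens the result back. A minimal standalone sketch of that round trip, with hypothetical sizes (batch 2, sequence 8, 4 heads of dim 64) standing in for the module's n_local_heads / head_dim / hidden_size attributes, assuming PyTorch >= 2.0:

import torch
import torch.nn.functional as F

# Hypothetical sizes standing in for the module's attributes (not from the commit)
batch_size, seqlen, n_heads, head_dim = 2, 8, 4, 64
hidden_size = n_heads * head_dim

# Flat projections as they leave the QKV linear: (batch, seqlen, hidden)
query = torch.randn(batch_size, seqlen, hidden_size)
key = torch.randn(batch_size, seqlen, hidden_size)
value = torch.randn(batch_size, seqlen, hidden_size)

# Split heads and move them ahead of the sequence dim: (batch, n_heads, seqlen, head_dim)
q = query.view(batch_size, seqlen, n_heads, head_dim).transpose(1, 2)
k = key.view(batch_size, seqlen, n_heads, head_dim).transpose(1, 2)
v = value.view(batch_size, seqlen, n_heads, head_dim).transpose(1, 2)

# One fused call computes softmax(q @ k^T / sqrt(head_dim)) @ v under a causal mask
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Merge heads back into the hidden dimension: (batch, seqlen, hidden)
out = out.transpose(1, 2).reshape(batch_size, seqlen, hidden_size)
print(out.shape)  # torch.Size([2, 8, 256])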
@@ -186,24 +208,12 @@ class QuantAttentionFused(nn.Module):
             self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
             self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
 
             keys = xk
             values = xv
             past_key_value = (xk, xv) if use_cache else None
 
-            xq = xq.transpose(1, 2)
-            keys = keys.transpose(1, 2)
-            values = values.transpose(1, 2)
-            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-            if self.use_alibi:
-                scores += self.alibi_bias[..., :seqlen]
-
-            if attention_mask is not None:
-                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
-
-            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
-            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+            output = self._multi_query_attention_torch(xq, xk, xv, bsz, seqlen, True, past_key_value, attention_mask)
         else:
             xq = xq[:, 0, :, :]
             xk = xk[:, 0, :, :]
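For the causal, non-ALiBi case, the removed explicit attention (scale, mask, softmax, weighted sum) and the new SDPA call compute the same result; the switch trades hand-written ops for one call that PyTorch can route to fused kernels. A quick numerical check of that equivalence, using made-up shapes and an explicit additive causal mask in place of the module's attention_mask (an illustration, not code from the commit):

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, n_heads, seqlen, head_dim = 1, 2, 5, 16  # made-up shapes

xq = torch.randn(bsz, n_heads, seqlen, head_dim)
keys = torch.randn(bsz, n_heads, seqlen, head_dim)
values = torch.randn(bsz, n_heads, seqlen, head_dim)

# Additive causal mask in the form the removed code expected: 0 where allowed, -inf above the diagonal
mask = torch.full((seqlen, seqlen), float("-inf")).triu(diagonal=1)

# Removed path: explicit scores -> mask -> softmax -> weighted sum
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
scores = scores + mask
manual = torch.matmul(F.softmax(scores.float(), dim=-1).type_as(xq), values)

# New path: a single fused call with an implicit causal mask
sdpa = F.scaled_dot_product_attention(xq, keys, values, is_causal=True)

print(torch.allclose(manual, sdpa, atol=1e-5))  # expected: True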
@@ -231,4 +241,4 @@ class QuantAttentionFused(nn.Module):
         else:
             self.start_pos = 0
 
-        return attn_output, None, past_key_value
+        return attn_output, None, past_key_value
\ No newline at end of file
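A side effect of going through torch SDPA is backend dispatch: on CUDA, PyTorch may pick the FlashAttention or memory-efficient kernel instead of the plain math path when dtypes and shapes allow it. The commit itself does not configure backends, but a caller could pin one around the forward pass roughly like this, assuming a CUDA device and PyTorch 2.0/2.1, where torch.backends.cuda.sdp_kernel is the relevant context manager:

import torch
import torch.nn.functional as F

# Hypothetical fp16 inputs on a CUDA device; shapes are illustrative only
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Allow only the FlashAttention kernel; SDPA raises if it cannot serve these inputs with it
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])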