OpenDAS / AutoAWQ · Commits

Commit 1f07200a (unverified)
Authored Apr 06, 2024 by Younes Belkada, committed via GitHub on Apr 06, 2024
Parent: 5d7b0502

FIX: Add safe guards for static cache + llama on transformers latest (#401)

Showing 1 changed file with 18 additions and 7 deletions:

awq/modules/fused/attn.py (+18, -7)
@@ -188,16 +188,19 @@ class QuantAttentionFused(nn.Module):
             # Always reset to 0
             self.start_pos = 0
 
+        hf_is_generating = False
+        if self.is_hf_transformers and "use_cache" in kwargs:
+            hf_is_generating = kwargs["use_cache"]
+
         # In case we re-generate, we need to refresh the starting position
         # to 0. We detect it by checking if `past_key_values` is set to None,
         # which indicates that we are on the first step of `generate()`.
         # This is only applicable for `transformers` integration
-        if (
-            self.is_hf_transformers
-            and "past_key_value" in kwargs
-            and kwargs["past_key_value"] is None
-        ):
+        if (
+            self.is_hf_transformers
+            and "past_key_value" in kwargs
+            and kwargs["past_key_value"] is None
+        ) or (self.is_hf_transformers and not hf_is_generating):
             self.start_pos = 0
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
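The reset guard above is self-contained enough to illustrate on its own. Below is a minimal sketch of when `start_pos` gets cleared; the helper name and signature are hypothetical (not part of AutoAWQ's API), only the condition mirrors the hunk.

# Hypothetical helper, for illustration only -- not part of AutoAWQ's API.
def should_reset_start_pos(is_hf_transformers: bool, **kwargs) -> bool:
    hf_is_generating = False
    if is_hf_transformers and "use_cache" in kwargs:
        hf_is_generating = kwargs["use_cache"]

    # First step of `generate()`: past_key_value is passed but still None.
    first_generate_step = (
        is_hf_transformers
        and "past_key_value" in kwargs
        and kwargs["past_key_value"] is None
    )
    # Also reset whenever transformers calls the module outside of generation,
    # so a stale static-cache offset cannot leak into an unrelated forward pass.
    return first_generate_step or (is_hf_transformers and not hf_is_generating)


# First generate() step resets; later decoding steps keep the running offset.
assert should_reset_start_pos(True, past_key_value=None, use_cache=True)
assert not should_reset_start_pos(True, past_key_value=object(), use_cache=True)
# A plain forward pass (use_cache=False) also resets.
assert should_reset_start_pos(True, past_key_value=object(), use_cache=False)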
@@ -214,8 +217,6 @@ class QuantAttentionFused(nn.Module):
         if not self.use_alibi:
             xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
-        self.cache.to(xq)
-
         values_store = xv.transpose(2, 1)
         keys_store = (
             xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
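For readers unfamiliar with the fused cache layout, the following sketch shows the tensor reshapes performed by the context lines above; the concrete dimensions and the (head_dim // 8, 8) key split are assumptions chosen for illustration, since the real values come from `attention_shapes`.

import torch

# Assumed toy dimensions; real values come from the model config / attention_shapes.
bsz, seqlen, n_kv_heads, head_dim = 2, 5, 4, 64

xv = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

# values_store: move the sequence axis behind the head axis so the cache is
# laid out as (bsz, n_kv_heads, seqlen, head_dim).
values_store = xv.transpose(2, 1)
assert values_store.shape == (bsz, n_kv_heads, seqlen, head_dim)

# keys_store: split head_dim for the fused cache; the (head_dim // 8, 8) split
# mirrors a typical "xk_reshape" entry but is only an assumption here.
keys_store = xk.reshape((bsz, seqlen) + (n_kv_heads, head_dim // 8, 8))
assert keys_store.shape == (bsz, seqlen, n_kv_heads, head_dim // 8, 8)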
@@ -223,6 +224,7 @@ class QuantAttentionFused(nn.Module):
             .contiguous()
         )
 
+        self.cache.to(xq)
         self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
 
         # Only necessary to retrieve from cache when we are not processing context
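The hunk moves `self.cache.to(xq)` so it runs immediately before `update_kv`. The toy class below is a hypothetical stand-in (not AutoAWQ's fused cache, and a simplified layout); it only sketches why that ordering matters: the preallocated buffers must match the activations' device and dtype before the in-place write.

import torch

class ToyStaticKVCache:
    """Toy stand-in for the fused KV cache (simplified layout, hypothetical API)."""

    def __init__(self, bsz, n_kv_heads, max_seq_len, head_dim):
        self.v = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)
        self.k = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)

    def to(self, ref: torch.Tensor):
        # Match the reference activation's device and dtype before writing.
        self.v = self.v.to(device=ref.device, dtype=ref.dtype)
        self.k = self.k.to(device=ref.device, dtype=ref.dtype)

    def update_kv(self, values, keys, start_pos, seqlen):
        # Write the new slice into the static buffers at the current offset.
        self.v[:, :, start_pos : start_pos + seqlen] = values
        self.k[:, :, start_pos : start_pos + seqlen] = keys


bsz, seqlen, n_kv_heads, head_dim = 1, 4, 2, 8
xq = torch.randn(bsz, seqlen, n_kv_heads, head_dim, dtype=torch.float16)
cache = ToyStaticKVCache(bsz, n_kv_heads, max_seq_len=16, head_dim=head_dim)

cache.to(xq)  # align with the fp16 activations first, as the hunk does
cache.update_kv(
    torch.randn(bsz, n_kv_heads, seqlen, head_dim, dtype=torch.float16),
    torch.randn(bsz, n_kv_heads, seqlen, head_dim, dtype=torch.float16),
    start_pos=0,
    seqlen=seqlen,
)
assert cache.v.dtype == torch.float16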
@@ -248,6 +250,11 @@ class QuantAttentionFused(nn.Module):
             # When seqlen is 1, there is nothing else to attend to
             if attention_mask is not None and seqlen > 1:
+                # For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we
+                # need to slice it
+                if attention_mask.shape[-1] != seqlen:
+                    attention_mask = attention_mask[:, :, :seqlen, :seqlen]
+
                 scores = (
                     scores + attention_mask
                 )  # (bs, n_local_heads, slen, cache_len + slen)
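A standalone illustration of the new slicing, assuming (as the added comment states) that the llama path hands the layer an additive causal mask preallocated at (bsz, 1, max_seq_len, max_seq_len):

import torch

bsz, max_seq_len, seqlen = 1, 8, 3

# Assumed preallocated additive causal mask: 0 on/below the diagonal, -inf above.
attention_mask = torch.triu(
    torch.full((bsz, 1, max_seq_len, max_seq_len), float("-inf")), diagonal=1
)

# Same guard as in the hunk: only the top-left seqlen x seqlen block applies
# to the current chunk of scores (seqlen x seqlen while processing context).
if attention_mask.shape[-1] != seqlen:
    attention_mask = attention_mask[:, :, :seqlen, :seqlen]

assert attention_mask.shape == (bsz, 1, seqlen, seqlen)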
@@ -278,11 +285,15 @@ class QuantAttentionFused(nn.Module):
         attn_output = self.o_proj(attention_weight)
         self.start_pos += seqlen
 
+        if self.is_hf_transformers and not hf_is_generating:
+            self.start_pos = 0
+
         # past_key_value is replaced with cache_v, cache_k, returning empty data
         # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
         # about past key length
         past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
 
         if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
             new_cache = DynamicCache()
             new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
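A sketch of the dummy-cache hand-off above, assuming a transformers release that ships `DynamicCache` (which is what the `HF_NEW_CACHE_FORMAT` guard checks for): the fused layer keeps the real keys/values itself and only reports the past length through the dummy tensor's sequence dimension.

import torch
from transformers import DynamicCache  # import path may vary by transformers version

start_pos = 5  # pretend five positions have already been processed
dummy = torch.zeros(1, 1, start_pos, 1)

# The zero tensor carries no real keys/values; only its dim -2 matters.
new_cache = DynamicCache()
new_cache.update(dummy, dummy, layer_idx=0)

# transformers reads the past key length back from the dummy entry.
assert new_cache.get_seq_length(layer_idx=0) == start_pos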