Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
74d0fe44
Unverified
Commit
74d0fe44
authored
Nov 15, 2023
by
Younes Belkada
Committed by
GitHub
Nov 15, 2023
Browse files
[`core`] Add `is_hf_transformers` flag (#195)
parent
3b362c0d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
1 deletion
+3
-1
awq/modules/fused/attn.py
awq/modules/fused/attn.py
+3
-1
No files found.
awq/modules/fused/attn.py
View file @
74d0fe44
...
@@ -100,6 +100,7 @@ class QuantAttentionFused(nn.Module):
...
@@ -100,6 +100,7 @@ class QuantAttentionFused(nn.Module):
self
.
use_alibi
=
use_alibi
self
.
use_alibi
=
use_alibi
self
.
cache_batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
self
.
cache_batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
self
.
max_seq_len
=
max_seq_len
self
.
max_seq_len
=
max_seq_len
self
.
is_hf_transformers
=
False
# attention shapes for self attention
# attention shapes for self attention
self
.
attention_shapes
=
get_attention_shapes
(
self
.
attention_shapes
=
get_attention_shapes
(
...
@@ -138,7 +139,8 @@ class QuantAttentionFused(nn.Module):
...
@@ -138,7 +139,8 @@ class QuantAttentionFused(nn.Module):
# In case we re-generate, we need to refresh the starting position
# In case we re-generate, we need to refresh the starting position
# to 0. We detect it by checking if `past_key_values` is set to None,
# to 0. We detect it by checking if `past_key_values` is set to None,
# which indicates that we are on the first step of `generate()`.
# which indicates that we are on the first step of `generate()`.
if
"past_key_value"
in
kwargs
and
kwargs
[
"past_key_value"
]
is
None
:
# This is only applicable for `transformers` integration
if
self
.
is_hf_transformers
and
"past_key_value"
in
kwargs
and
kwargs
[
"past_key_value"
]
is
None
:
self
.
start_pos
=
0
self
.
start_pos
=
0
xqkv
=
self
.
qkv_proj
(
hidden_states
)
xqkv
=
self
.
qkv_proj
(
hidden_states
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment