OpenDAS / AutoAWQ · Commit 0392a823

Authored Sep 26, 2023 by Casper Hansen

Fix edge case

Parent: 3bb4a9f6
Showing 1 changed file with 22 additions and 8 deletions.

awq/modules/fused/attn.py  +22 -8
@@ -5,6 +5,7 @@ import logging
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
+from awq.utils.utils import compute_memory_used_pct
 try:
     import ft_inference_engine
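Note: the one added line in this hunk pulls in compute_memory_used_pct from awq.utils.utils, which the second hunk uses to decide whether the cache may grow. The helper's implementation is not part of this diff; a rough stand-in (an assumption about its behaviour, not its actual source) would report the share of device memory currently allocated:

import torch

def compute_memory_used_pct(device):
    # Assumed behaviour only: percent of total CUDA memory allocated on `device`.
    used = torch.cuda.memory_allocated(device)
    total = torch.cuda.get_device_properties(device).total_memory
    return used / total * 100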
@@ -164,16 +165,22 @@ class QuantAttentionFused(nn.Module):
                 f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
             )
 
-        if self.start_pos > self.max_seq_len:
+        if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
             logging.warning('You have exceeded max_new_tokens, resetting cache...')
             self._initialize_cache(hidden_states.device)
             self.start_pos = 0
         elif seqlen > self.max_seq_len:
-            logging.warning('Sequence length > max_seq_len, increasing and resetting cache...')
-            self.max_seq_len *= 2
-            self._initialize_cache(hidden_states.device)
-            self.start_pos = 0
+            memory_used = compute_memory_used_pct(hidden_states.device)
+
+            if memory_used <= 80:
+                logging.warning('Input sequence length > max_seq_len, increasing and resetting cache...')
+                self.max_seq_len += self.max_seq_len
+                self.attention_shapes = self._get_attention_shapes(None, self.max_seq_len)
+                self._initialize_cache(hidden_states.device)
+                self.start_pos = 0
+            else:
+                logging.error('Input sequence length > max_seq_len, memory is filled, exiting...')
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
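Note: the tightened condition matters when a request lands near the end of the cache. The old check only fired after start_pos had already passed max_seq_len, so a write of seqlen new positions could still run past the last cache slot. A quick sketch with made-up numbers (not taken from the commit) shows the boundary the new "or" clause catches:

max_seq_len = 2048
start_pos = 2040                     # positions already cached
seqlen = 16                          # new tokens arriving this step

old_check = start_pos > max_seq_len
new_check = start_pos > max_seq_len or start_pos + seqlen > max_seq_len
print(old_check, new_check)          # False True -> only the new check resets the cache in time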
@@ -200,8 +207,15 @@ class QuantAttentionFused(nn.Module):
             .contiguous()
         )
 
-        self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
-        self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
+        try:
+            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
+            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
+        except Exception as ex:
+            print(seqlen, self.max_seq_len)
+            print(self.cache_v.shape, self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :].shape, values_store.shape)
+            print(self.cache_k.shape, self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :].shape, keys_store.shape)
+            print(ex)
+            exit(0)
 
         if seqlen == 1:
             xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
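Note: the try/except added around the cache writes surfaces exactly this kind of overflow. Slicing past the end of the cache silently clamps, so the assignment fails with a shape mismatch rather than an out-of-range error, and the new prints report the offending shapes. A self-contained sketch with made-up sizes (not the module's real shapes) reproduces the failure:

import torch

cache_v = torch.zeros(1, 8, 2048, 64)      # (batch, heads, max_seq_len, head_dim)
values_store = torch.zeros(1, 8, 16, 64)   # 16 new positions to store
start_pos, seqlen = 2040, 16

try:
    # The slice clamps at 2048, so the target is (1, 8, 8, 64) and assigning a
    # (1, 8, 16, 64) tensor raises a size-mismatch RuntimeError.
    cache_v[:1, :, start_pos : start_pos + seqlen, :] = values_store
except RuntimeError as ex:
    print(seqlen, 2048)
    print(cache_v.shape, cache_v[:1, :, start_pos : start_pos + seqlen, :].shape, values_store.shape)
    print(ex)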