OpenDAS / AutoAWQ / Commits / df909e83

Unverified commit df909e83, authored Nov 11, 2023 by Casper, committed by GitHub on Nov 11, 2023
parent 5db86ec5

Reset cache on new generation (#178)

Showing 3 changed files with 23 additions and 13 deletions (+23 -13):

  awq/modules/fused/attn.py    +0  -9
  awq/modules/fused/model.py   +10 -4
  awq/utils/fused_utils.py     +13 -0
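Context for the change, seen from the caller's side: with fused layers the KV cache persists inside the model between generate() calls, so a second generation must not continue from the previous prompt's cache position. A hedged usage sketch of that scenario follows; the quantized checkpoint path and generation settings are illustrative, and only batch_size (named in the attn.py error message below) and fuse_layers are taken from AutoAWQ's from_quantized options.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Illustrative checkpoint path; any AWQ-quantized model with fused layers applies.
quant_path = "TheBloke/Mistral-7B-Instruct-v0.1-AWQ"

# batch_size is the parameter referenced in the attn.py error message below;
# fuse_layers=True enables the fused modules this commit touches.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, batch_size=1)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

# First generation fills the fused KV cache up to some start_pos.
tokens = tokenizer("Tell me about AWQ quantization.", return_tensors="pt").input_ids.cuda()
out_1 = model.generate(tokens, max_new_tokens=64)

# Second, unrelated prompt: after this commit, prepare_cache() resets or rolls
# every block's cache before the new context is processed, instead of relying
# on per-layer checks inside the attention module.
tokens = tokenizer("Write a haiku about GPUs.", return_tensors="pt").input_ids.cuda()
out_2 = model.generate(tokens, max_new_tokens=64)

print(tokenizer.decode(out_2[0], skip_special_tokens=True))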
awq/modules/fused/attn.py  View file @ df909e83

...
@@ -128,15 +128,6 @@ class QuantAttentionFused(nn.Module):
                 f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                 f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
             )

-        will_cache_be_exceeded = self.start_pos + seqlen > self.max_seq_len
-        # Reset and avoid retaining state when processing context
-        if will_cache_be_exceeded and seqlen > 1:
-            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=self.start_pos)
-        # Slowly roll out old tokens without performance hit if exceeded during decoding
-        elif will_cache_be_exceeded and seqlen == 1:
-            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=100)
-
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
...
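Both removed branches call cache.roll_kv_n_steps and differ only in how far the window shifts: by the full start_pos when a multi-token context arrives (an effective reset), or by a fixed 100 positions during single-token decoding. Below is a minimal sketch of what such a rolling window could look like, assuming a preallocated cache of shape (batch, heads, max_seq_len, head_dim); it is not AutoAWQ's actual cache class.

import torch

class WindowedCacheSketch:
    """Hypothetical rolling KV cache, for illustration only."""

    def __init__(self, batch_size: int, n_heads: int, max_seq_len: int, head_dim: int):
        self.k = torch.zeros(batch_size, n_heads, max_seq_len, head_dim)
        self.v = torch.zeros(batch_size, n_heads, max_seq_len, head_dim)

    def roll_kv_n_steps(self, start_pos: int, n: int) -> int:
        """Evict the oldest n positions by shifting the window left and
        zeroing the freed tail, then return the new write position."""
        n = min(n, start_pos)
        if n == 0:
            return start_pos
        self.k = torch.roll(self.k, shifts=-n, dims=2)
        self.v = torch.roll(self.v, shifts=-n, dims=2)
        self.k[:, :, -n:, :] = 0
        self.v[:, :, -n:, :] = 0
        return start_pos - n

# n=100 evicts only a small slice of the oldest tokens per decode step, keeping
# decoding cheap; n=start_pos clears the whole window for a fresh context.
cache = WindowedCacheSketch(batch_size=1, n_heads=8, max_seq_len=2048, head_dim=64)
print(cache.roll_kv_n_steps(start_pos=2048, n=100))    # 1948: slow roll while decoding
print(cache.roll_kv_n_steps(start_pos=1948, n=1948))   # 0: full reset on a new context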
awq/modules/fused/model.py  View file @ df909e83

...
@@ -2,8 +2,8 @@ import torch
 import torch.nn as nn
 from typing import List
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids
 from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
+from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids, prepare_cache

 class LlamaLikeModel(nn.Module):
     """
...
@@ -24,8 +24,10 @@ class LlamaLikeModel(nn.Module):
             input_ids, self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape
+
+        prepare_cache(self.blocks, seqlen)

         h = self.embedding(input_ids)
         mask = prepare_attention_mask(
...
@@ -58,8 +60,10 @@ class MPTModel(nn.Module):
             input_ids, self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape
+
+        prepare_cache(self.blocks, seqlen)

         h = self.wte(input_ids)
         mask = prepare_attention_mask(
...
@@ -92,8 +96,10 @@ class FalconModel(nn.Module):
             input_ids, self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape
+
+        prepare_cache(self.blocks, seqlen)

         h = self.word_embeddings(input_ids)
         mask = prepare_attention_mask(
...
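Moving the check out of QuantAttentionFused and into prepare_cache means the reset/roll decision is made once per forward pass, for every block, before any layer writes to its cache. The sketch below exercises that behaviour with stand-in blocks built from SimpleNamespace (not AutoAWQ classes); only the attributes prepare_cache reads (attn.start_pos, attn.max_seq_len, attn.cache.roll_kv_n_steps) are stubbed, and the expected values in the comments assume the prepare_cache added in this commit.

from types import SimpleNamespace
from awq.utils.fused_utils import prepare_cache

def make_stub_block(start_pos: int, max_seq_len: int):
    # Stand-in "block": exposes only what prepare_cache touches.
    cache = SimpleNamespace(
        roll_kv_n_steps=lambda start_pos, n: max(start_pos - n, 0)
    )
    attn = SimpleNamespace(start_pos=start_pos, max_seq_len=max_seq_len, cache=cache)
    return SimpleNamespace(attn=attn)

# Multi-token context: every block's cache position is reset before any layer runs.
blocks = [make_stub_block(start_pos=1900, max_seq_len=2048) for _ in range(4)]
prepare_cache(blocks, seqlen=300)
print([b.attn.start_pos for b in blocks])   # [0, 0, 0, 0]

# Single-token decoding at the cache limit: only a small slice is evicted per step.
blocks = [make_stub_block(start_pos=2048, max_seq_len=2048) for _ in range(4)]
prepare_cache(blocks, seqlen=1)
print([b.attn.start_pos for b in blocks])   # [1948, 1948, 1948, 1948]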
awq/utils/fused_utils.py  View file @ df909e83

 import torch

 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV


+def prepare_cache(blocks, seqlen: int) -> int:
+    for block in blocks:
+        start_pos = block.attn.start_pos
+        will_cache_be_exceeded = start_pos + seqlen > block.attn.max_seq_len
+
+        # Reset and avoid retaining state when processing context
+        if seqlen > 1 and (will_cache_be_exceeded or seqlen > 1):
+            block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=start_pos)
+        # Slowly roll out old tokens without performance hit if exceeded during decoding
+        elif seqlen == 1 and will_cache_be_exceeded:
+            block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=100)
+
+
 def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
     # NOTE: from transformers 4.35.0, input_ids includes full context during decoding
     num_input_tokens = input_ids.shape[-1]
...
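The view is truncated after the first statement of prepare_input_ids. Based only on the NOTE shown above (from transformers 4.35.0 the decoding loop passes the full context), here is one plausible way such a helper could continue; this is an assumption for illustration, not the actual remainder of the file.

import torch

def prepare_input_ids_sketch(input_ids: torch.Tensor, last_forward_num_tokens: int):
    """Hypothetical continuation of prepare_input_ids, for illustration only.

    Since transformers 4.35.0 the decoding loop passes the full context, so the
    fused path must slice input_ids down to the tokens it has not yet processed.
    """
    num_input_tokens = input_ids.shape[-1]

    if num_input_tokens > last_forward_num_tokens:
        # Keep only the not-yet-processed suffix: the whole prompt on the first
        # call, a single new token on each decode step.
        num_new_tokens = num_input_tokens - last_forward_num_tokens
        input_ids = input_ids[:, -num_new_tokens:]
    else:
        num_new_tokens = num_input_tokens

    return input_ids, last_forward_num_tokens + num_new_tokens

# First call: the whole 5-token prompt is new.
ids = torch.arange(5).unsqueeze(0)
ids, seen = prepare_input_ids_sketch(ids, last_forward_num_tokens=0)
print(ids.shape, seen)   # torch.Size([1, 5]) 5

# Decode step: transformers passes 6 tokens, only the last one is new.
ids = torch.arange(6).unsqueeze(0)
ids, seen = prepare_input_ids_sketch(ids, last_forward_num_tokens=5)
print(ids.shape, seen)   # torch.Size([1, 1]) 6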