Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
b13e2a85
Commit
b13e2a85
authored
Oct 06, 2023
by
Casper Hansen
Browse files
Remove attention sinks
parent
69733d2c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
11 deletions
+3
-11
awq/modules/fused/cache.py
awq/modules/fused/cache.py
+3
-11
No files found.
awq/modules/fused/cache.py
View file @
b13e2a85
import
torch
import
torch
class
WindowedCache
:
class
WindowedCache
:
def
__init__
(
self
,
cache_v_shape
,
cache_k_shape
,
device
,
attention_sinks
=
4
):
def
__init__
(
self
,
cache_v_shape
,
cache_k_shape
,
device
):
"""
"""
The window size is the same as the max_new_tokens. The window will
The window size is the same as the max_new_tokens. The window will
automatically roll once max_new_tokens is exceeded.
automatically roll once max_new_tokens is exceeded.
"""
"""
self
.
attention_sinks
=
attention_sinks
# [batch_size, n_kv_heads, max_seq_len, head_dim]
# [batch_size, n_kv_heads, max_seq_len, head_dim]
self
.
v
=
torch
.
zeros
(
cache_v_shape
).
to
(
device
).
half
()
self
.
v
=
torch
.
zeros
(
cache_v_shape
).
to
(
device
).
half
()
# [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
# [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
...
@@ -25,15 +23,9 @@ class WindowedCache:
...
@@ -25,15 +23,9 @@ class WindowedCache:
self
.
k
[:
batch_size
,
:,
:,
start_pos
:
start_pos
+
seqlen
,
:]
=
keys_store
self
.
k
[:
batch_size
,
:,
:,
start_pos
:
start_pos
+
seqlen
,
:]
=
keys_store
def
roll_kv
(
self
,
roll_len
,
start_pos
):
def
roll_kv
(
self
,
roll_len
,
start_pos
):
"""
With sink=0, roll_len=3, and [A,B,C,D,E] we get [D,E,F,G,H]
With sink=1, roll_len=3, and [A,B,C,D,E] we get [A,E,F,G,H]
With sink=2, roll_len=3, and [A,B,C,D,E] we get [A,B,F,G,H]
With sink=3, roll_len=3, and [A,B,C,D,E] we get [A,B,C,G,H]
"""
# Roll only the necessary part of the cache to the left
# Roll only the necessary part of the cache to the left
self
.
v
[:,
:,
self
.
attention_sinks
:
-
roll_len
+
self
.
attention_sinks
,
:]
=
self
.
v
[:,
:,
roll_len
:,
:]
self
.
v
[:,
:,
:
-
roll_len
,
:]
=
self
.
v
[:,
:,
roll_len
:,
:]
self
.
k
[:,
:,
:,
self
.
attention_sinks
:
-
roll_len
+
self
.
attention_sinks
,
:]
=
self
.
k
[:,
:,
:,
roll_len
:,
:]
self
.
k
[:,
:,
:,
:
-
roll_len
,
:]
=
self
.
k
[:,
:,
:,
roll_len
:,
:]
# Zero out the new part
# Zero out the new part
self
.
v
[:,
:,
-
roll_len
:,
:]
=
0
self
.
v
[:,
:,
-
roll_len
:,
:]
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment