Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9719202d
Unverified
Commit
9719202d
authored
May 02, 2024
by
Joao Gante
Committed by
GitHub
May 02, 2024
Browse files
Generate: fix `SinkCache` on Llama models (#30581)
parent
66abe139
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
5 deletions
+22
-5
src/transformers/cache_utils.py
src/transformers/cache_utils.py
+22
-5
No files found.
src/transformers/cache_utils.py
View file @
9719202d
...
...
@@ -207,7 +207,9 @@ class SinkCache(Cache):
self
.
value_cache
:
List
[
torch
.
Tensor
]
=
[]
self
.
window_length
=
window_length
self
.
num_sink_tokens
=
num_sink_tokens
self
.
cos_sin_cache
=
{}
self
.
cos_sin_rerotation_cache
=
{}
self
.
_cos_cache
=
None
self
.
_sin_cache
=
None
self
.
_seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
@
staticmethod
...
...
@@ -225,7 +227,7 @@ class SinkCache(Cache):
def
_get_rerotation_cos_sin
(
self
,
key_states
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
key_states
.
shape
[
-
2
]
not
in
self
.
cos_sin_cache
:
if
key_states
.
shape
[
-
2
]
not
in
self
.
cos_sin_
rerotation_
cache
:
# Upcast to float32 temporarily for better accuracy
cos
=
cos
.
to
(
torch
.
float32
)
sin
=
sin
.
to
(
torch
.
float32
)
...
...
@@ -238,11 +240,11 @@ class SinkCache(Cache):
rerotation_cos
=
original_cos
*
shifted_cos
+
original_sin
*
shifted_sin
rerotation_sin
=
-
original_sin
*
shifted_cos
+
original_cos
*
shifted_sin
self
.
cos_sin_cache
[
key_states
.
shape
[
-
2
]]
=
(
self
.
cos_sin_
rerotation_
cache
[
key_states
.
shape
[
-
2
]]
=
(
rerotation_cos
.
to
(
key_states
.
dtype
).
unsqueeze
(
0
),
rerotation_sin
.
to
(
key_states
.
dtype
).
unsqueeze
(
0
),
)
return
self
.
cos_sin_cache
[
key_states
.
shape
[
-
2
]]
return
self
.
cos_sin_
rerotation_
cache
[
key_states
.
shape
[
-
2
]]
def
get_seq_length
(
self
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
...
...
@@ -292,6 +294,21 @@ class SinkCache(Cache):
if
layer_idx
==
0
:
self
.
_seen_tokens
+=
key_states
.
shape
[
-
2
]
# Update the sin/cos cache, which holds sin/cos values for all possible positions
if
using_rope
and
layer_idx
==
0
:
# BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
# after all RoPE models have a llama-like cache utilization.
if
cos
.
dim
()
==
2
:
self
.
_cos_cache
=
cos
self
.
_sin_cache
=
sin
else
:
if
self
.
_cos_cache
is
None
:
self
.
_cos_cache
=
cos
[
0
,
...]
self
.
_sin_cache
=
sin
[
0
,
...]
elif
self
.
_cos_cache
.
shape
[
0
]
<
self
.
window_length
:
self
.
_cos_cache
=
torch
.
cat
([
self
.
_cos_cache
,
cos
[
0
,
...]],
dim
=
0
)
self
.
_sin_cache
=
torch
.
cat
([
self
.
_sin_cache
,
sin
[
0
,
...]],
dim
=
0
)
# [bsz, num_heads, seq_len, head_dim]
if
len
(
self
.
key_cache
)
<=
layer_idx
:
# Empty cache
...
...
@@ -312,7 +329,7 @@ class SinkCache(Cache):
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
if
using_rope
:
rerotation_cos
,
rerotation_sin
=
self
.
_get_rerotation_cos_sin
(
key_states
,
cos
[:
self
.
window_length
],
s
in
[:
self
.
window_length
]
key_states
,
self
.
_cos_cache
[:
self
.
window_length
],
s
elf
.
_sin_cache
[:
self
.
window_length
]
)
if
partial_rotation_size
is
not
None
:
keys_to_keep
,
keys_pass
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment