chenpangpang / transformers, commit 9719202d (unverified)
Authored by Joao Gante on May 02, 2024; committed via GitHub on May 02, 2024
Generate: fix `SinkCache` on Llama models (#30581)
Parent: 66abe139

Showing 1 changed file with 22 additions and 5 deletions

src/transformers/cache_utils.py (+22 −5)
@@ -207,7 +207,9 @@ class SinkCache(Cache):
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
-        self.cos_sin_cache = {}
+        self.cos_sin_rerotation_cache = {}
+        self._cos_cache = None
+        self._sin_cache = None
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

     @staticmethod
@@ -225,7 +227,7 @@ class SinkCache(Cache):
     def _get_rerotation_cos_sin(
         self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
+        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
             # Upcast to float32 temporarily for better accuracy
             cos = cos.to(torch.float32)
             sin = sin.to(torch.float32)
@@ -238,11 +240,11 @@ class SinkCache(Cache):
         rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
         rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
-        self.cos_sin_cache[key_states.shape[-2]] = (
+        self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
             rerotation_cos.to(key_states.dtype).unsqueeze(0),
             rerotation_sin.to(key_states.dtype).unsqueeze(0),
         )
-        return self.cos_sin_cache[key_states.shape[-2]]
+        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
@@ -292,6 +294,21 @@ class SinkCache(Cache):
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]

+        # Update the sin/cos cache, which holds sin/cos values for all possible positions
+        if using_rope and layer_idx == 0:
+            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
+            # after all RoPE models have a llama-like cache utilization.
+            if cos.dim() == 2:
+                self._cos_cache = cos
+                self._sin_cache = sin
+            else:
+                if self._cos_cache is None:
+                    self._cos_cache = cos[0, ...]
+                    self._sin_cache = sin[0, ...]
+                elif self._cos_cache.shape[0] < self.window_length:
+                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
+                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
+
         # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
             # Empty cache
@@ -312,7 +329,7 @@ class SinkCache(Cache):
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
                 rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, cos[: self.window_length], sin[: self.window_length]
+                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                 )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
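For reference, below is a minimal usage sketch of the code path this commit touches, assuming a transformers release that exposes SinkCache and accepts a cache object via the past_key_values argument of generate; the checkpoint name, window sizes, and generation settings are illustrative only, not part of the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

# Illustrative checkpoint; any Llama-like RoPE model exercises the fixed code path.
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

inputs = tokenizer("Attention sinks let the cache slide without dropping the first tokens.", return_tensors="pt").to("cuda")

# SinkCache keeps `num_sink_tokens` attention sinks plus a sliding window of up to
# `window_length` tokens. With this commit it accumulates the full sin/cos values
# internally (`_cos_cache`/`_sin_cache`), so shifted keys can be re-rotated even when
# the model only passes sin/cos for the current positions, as Llama now does.
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
output = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))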