chenpangpang / transformers / Commits / a3fb96a4

Unverified commit a3fb96a4, authored Jun 26, 2024 by Joao Gante, committed by GitHub on Jun 26, 2024

Generate: fix assisted generation with `past_key_values` passed as kwargs (#31644)
parent 492ee17e

Showing 3 changed files with 21 additions and 35 deletions (+21 -35)
src/transformers/cache_utils.py                      +9   -9
src/transformers/generation/candidate_generator.py   +11  -24
src/transformers/generation/utils.py                 +1   -2
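
For context, the failure mode this commit addresses shows up when a caller passes an explicit cache to `generate()` while also requesting assisted generation. A minimal repro sketch, assuming small illustrative checkpoints (not taken from the commit itself):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# Illustrative checkpoints; any compatible main/assistant pair works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Passing `past_key_values` as a kwarg together with `assistant_model` is
# the case fixed here; before this commit it could crash inside assisted
# generation rather than being used as the starting cache.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    past_key_values=DynamicCache(),
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))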
src/transformers/cache_utils.py  (view file @ a3fb96a4)
...
@@ -395,21 +395,21 @@ class DynamicCache(Cache):
                 cache.update(key_states, value_states, layer_idx)
         return cache
 
-    def crop(self, maximum_length: int):
-        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
-        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+    def crop(self, max_length: int):
+        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
+        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
         # In case it is negative
-        if maximum_length < 0:
-            maximum_length = self.get_seq_length() - abs(maximum_length)
+        if max_length < 0:
+            max_length = self.get_seq_length() - abs(max_length)
 
-        if self.get_seq_length() <= maximum_length:
+        if self.get_seq_length() <= max_length:
             return
 
-        self._seen_tokens = maximum_length
+        self._seen_tokens = max_length
         for idx in range(len(self.key_cache)):
-            self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :]
-            self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :]
+            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
 
     def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
...
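
The `crop()` rename above (`maximum_length` -> `max_length`) keeps the behavior: a positive value truncates the cache to that many tokens, and a negative value drops that many tokens from the end. A short usage sketch, with arbitrary illustrative shapes:

import torch
from transformers.cache_utils import DynamicCache

# Build a one-layer cache of shape (batch=1, heads=2, seq_len=8, head_dim=4).
cache = DynamicCache()
cache.update(torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4), layer_idx=0)
print(cache.get_seq_length())  # 8

cache.crop(max_length=5)       # keep the first 5 tokens
print(cache.get_seq_length())  # 5

cache.crop(max_length=-2)      # negative: drop 2 tokens from the end
print(cache.get_seq_length())  # 3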
src/transformers/generation/candidate_generator.py  (view file @ a3fb96a4)
...
@@ -111,24 +111,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # Prepare the kwargs for the assistant model
         assistant_kwargs = {}
         for key, value in model_kwargs.items():
             # deepcopy crashes if we attempt to copy encoder outputs with grads
-            if key not in ("encoder_outputs", "assistant_encoder_outputs"):
+            if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
                 assistant_kwargs[key] = (
                     value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                 )
 
-        # Remove potential default DynamicCache if assistant does not support it
-        if "past_key_values" in assistant_kwargs.keys():
-            if (
-                isinstance(assistant_kwargs["past_key_values"], DynamicCache)
-                and not self.assistant_model._supports_cache_class
-            ):
-                # Cache is empty -> remove it from kwargs
-                if len(assistant_kwargs["past_key_values"]) == 0:
-                    del assistant_kwargs["past_key_values"]
-                # Cache is not empty -> convert to legacy
-                else:
-                    assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache()
-
         if "assistant_encoder_outputs" in model_kwargs:
             assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
         elif assistant_model.config.is_encoder_decoder:
...
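
The hunk above adds `past_key_values` to the exclusion tuple, so the main model's cache is no longer detached/deep-copied into the assistant's kwargs, and the follow-up block that special-cased a default `DynamicCache` becomes unnecessary. The copying pattern itself, isolated into a standalone sketch (this helper is not part of the library; it just mirrors the loop above):

import copy
import torch

EXCLUDED_KEYS = ("encoder_outputs", "assistant_encoder_outputs", "past_key_values")

def copy_kwargs_for_assistant(model_kwargs: dict, device: torch.device) -> dict:
    """Copy generation kwargs for a second model: tensors are detached and
    moved (deepcopy fails on tensors that carry grad history), and keys
    that are handled separately are skipped."""
    assistant_kwargs = {}
    for key, value in model_kwargs.items():
        if key in EXCLUDED_KEYS:
            continue  # caches and encoder outputs are handled elsewhere
        if isinstance(value, torch.Tensor):
            assistant_kwargs[key] = value.detach().to(device)
        else:
            assistant_kwargs[key] = copy.deepcopy(value)
    return assistant_kwargs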
@@ -363,15 +350,15 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
         return
 
 
-def _crop_past_key_values(model, past_key_values, maximum_length):
+def _crop_past_key_values(model, past_key_values, max_length):
     """Crops the past key values up to a certain maximum length."""
     new_past = []
     if model.config.is_encoder_decoder:
         for idx in range(len(past_key_values)):
             new_past.append(
                 (
-                    past_key_values[idx][0][:, :, :maximum_length, :],
-                    past_key_values[idx][1][:, :, :maximum_length, :],
+                    past_key_values[idx][0][:, :, :max_length, :],
+                    past_key_values[idx][1][:, :, :max_length, :],
                     past_key_values[idx][2],
                     past_key_values[idx][3],
                 )
...
@@ -384,8 +371,8 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
         for idx in range(len(past_key_values)):
             new_past.append(
                 (
-                    past_key_values[idx][0][:, :, :maximum_length],
-                    past_key_values[idx][1][:, :maximum_length, :],
+                    past_key_values[idx][0][:, :, :max_length],
+                    past_key_values[idx][1][:, :max_length, :],
                 )
             )
         past_key_values = tuple(new_past)
...
@@ -395,19 +382,19 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
     ):
         if model.config.multi_query:
             for idx in range(len(past_key_values)):
-                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
+                past_key_values[idx] = past_key_values[idx][:, :max_length, :]
         else:
             for idx in range(len(past_key_values)):
-                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
+                past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
     elif isinstance(past_key_values, DynamicCache):
-        past_key_values.crop(maximum_length)
+        past_key_values.crop(max_length)
     elif past_key_values is not None:
         for idx in range(len(past_key_values)):
             new_past.append(
                 (
-                    past_key_values[idx][0][:, :, :maximum_length, :],
-                    past_key_values[idx][1][:, :, :maximum_length, :],
+                    past_key_values[idx][0][:, :, :max_length, :],
+                    past_key_values[idx][1][:, :, :max_length, :],
                 )
             )
         past_key_values = tuple(new_past)
...
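
For legacy tuple caches, `_crop_past_key_values` slices every per-layer tensor along its sequence dimension; where that dimension sits varies by architecture, hence the separate branches above. A toy illustration of the common `(batch, num_heads, seq_len, head_dim)` layout, with arbitrary shapes:

import torch

# Three layers of (keys, values), each (batch=1, heads=2, seq_len=10, head_dim=4).
legacy_cache = tuple(
    (torch.randn(1, 2, 10, 4), torch.randn(1, 2, 10, 4)) for _ in range(3)
)

max_length = 6
cropped = tuple(
    (keys[:, :, :max_length, :], values[:, :, :max_length, :])
    for keys, values in legacy_cache
)
print(cropped[0][0].shape)  # torch.Size([1, 2, 6, 4])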
src/transformers/generation/utils.py  (view file @ a3fb96a4)
...
@@ -3697,11 +3697,10 @@ class GenerationMixin:
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
         # This is needed if return_dict_in_generate is True
+        start_from_empty_dynamic_cache = False
         if isinstance(model_kwargs.get("past_key_values", None), DynamicCache):
             if len(model_kwargs["past_key_values"]) == 0:
                 start_from_empty_dynamic_cache = True
-            else:
-                start_from_empty_dynamic_cache = False
 
         this_peer_finished = False
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
...
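
Hoisting `start_from_empty_dynamic_cache = False` above the `isinstance` check is the generation-loop half of the fix: the flag is now bound no matter what kind of `past_key_values` arrives in `model_kwargs`, whereas before it was only assigned inside the `DynamicCache` branch. A minimal sketch of the pitfall, with hypothetical stand-in names (not the transformers implementation):

def flag_before_fix(past_key_values):
    if isinstance(past_key_values, dict):  # stand-in for the DynamicCache check
        if len(past_key_values) == 0:
            start_from_empty = True
        else:
            start_from_empty = False
    return start_from_empty  # NameError if the isinstance check fails

def flag_after_fix(past_key_values):
    start_from_empty = False  # always bound
    if isinstance(past_key_values, dict) and len(past_key_values) == 0:
        start_from_empty = True
    return start_from_empty

print(flag_after_fix({}))  # True  (empty cache)
print(flag_after_fix(()))  # False (legacy tuple: no crash after the fix)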