Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a3fb96a4
Unverified
Commit
a3fb96a4
authored Jun 26, 2024 by Joao Gante
Committed by GitHub on Jun 26, 2024
Browse files
Generate: fix assisted generation with `past_key_values` passed as kwargs (#31644)
parent
492ee17e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing 3 changed files with 21 additions and 35 deletions
+21
-35
src/transformers/cache_utils.py
src/transformers/cache_utils.py
+9
-9
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+11
-24
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+1
-2
No files found.
src/transformers/cache_utils.py
View file @
a3fb96a4
...
...
@@ -395,21 +395,21 @@ class DynamicCache(Cache):
cache
.
update
(
key_states
,
value_states
,
layer_idx
)
return
cache
def
crop
(
self
,
max
imum
_length
:
int
):
"""Crop the past key values up to a new `max
imum
_length` in terms of tokens. `max
imum
_length` can also be
negative to remove `max
imum
_length` tokens. This is used in assisted decoding and contrastive search."""
def
crop
(
self
,
max_length
:
int
):
"""Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
# In case it is negative
if
max
imum
_length
<
0
:
max
imum
_length
=
self
.
get_seq_length
()
-
abs
(
max
imum
_length
)
if
max_length
<
0
:
max_length
=
self
.
get_seq_length
()
-
abs
(
max_length
)
if
self
.
get_seq_length
()
<=
max
imum
_length
:
if
self
.
get_seq_length
()
<=
max_length
:
return
self
.
_seen_tokens
=
max
imum
_length
self
.
_seen_tokens
=
max_length
for
idx
in
range
(
len
(
self
.
key_cache
)):
self
.
key_cache
[
idx
]
=
self
.
key_cache
[
idx
][...,
:
max
imum
_length
,
:]
self
.
value_cache
[
idx
]
=
self
.
value_cache
[
idx
][...,
:
max
imum
_length
,
:]
self
.
key_cache
[
idx
]
=
self
.
key_cache
[
idx
][...,
:
max_length
,
:]
self
.
value_cache
[
idx
]
=
self
.
value_cache
[
idx
][...,
:
max_length
,
:]
def
batch_split
(
self
,
full_batch_size
:
int
,
split_size
:
int
)
->
List
[
"DynamicCache"
]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
...
...
src/transformers/generation/candidate_generator.py
View file @
a3fb96a4
...
...
@@ -111,24 +111,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
# Prepare the kwargs for the assistant model
assistant_kwargs
=
{}
for
key
,
value
in
model_kwargs
.
items
():
# deepcopy crashes if we attempt to copy encoder outputs with grads
if
key
not
in
(
"encoder_outputs"
,
"assistant_encoder_outputs"
):
if
key
not
in
(
"encoder_outputs"
,
"assistant_encoder_outputs"
,
"past_key_values"
):
assistant_kwargs
[
key
]
=
(
value
.
detach
().
to
(
device
)
if
isinstance
(
value
,
torch
.
Tensor
)
else
copy
.
deepcopy
(
value
)
)
# Remove potential default DynamicCache if assistant does not support it
if
"past_key_values"
in
assistant_kwargs
.
keys
():
if
(
isinstance
(
assistant_kwargs
[
"past_key_values"
],
DynamicCache
)
and
not
self
.
assistant_model
.
_supports_cache_class
):
# Cache is empty -> remove it from kwargs
if
len
(
assistant_kwargs
[
"past_key_values"
])
==
0
:
del
assistant_kwargs
[
"past_key_values"
]
# Cache is not empty -> convert to legacy
else
:
assistant_kwargs
[
"past_key_values"
]
=
assistant_kwargs
[
"past_key_values"
].
to_legacy_cache
()
if
"assistant_encoder_outputs"
in
model_kwargs
:
assistant_kwargs
[
"encoder_outputs"
]
=
model_kwargs
[
"assistant_encoder_outputs"
]
elif
assistant_model
.
config
.
is_encoder_decoder
:
...
...
@@ -363,15 +350,15 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
return
def
_crop_past_key_values
(
model
,
past_key_values
,
max
imum
_length
):
def
_crop_past_key_values
(
model
,
past_key_values
,
max_length
):
"""Crops the past key values up to a certain maximum length."""
new_past
=
[]
if
model
.
config
.
is_encoder_decoder
:
for
idx
in
range
(
len
(
past_key_values
)):
new_past
.
append
(
(
past_key_values
[
idx
][
0
][:,
:,
:
max
imum
_length
,
:],
past_key_values
[
idx
][
1
][:,
:,
:
max
imum
_length
,
:],
past_key_values
[
idx
][
0
][:,
:,
:
max_length
,
:],
past_key_values
[
idx
][
1
][:,
:,
:
max_length
,
:],
past_key_values
[
idx
][
2
],
past_key_values
[
idx
][
3
],
)
...
...
@@ -384,8 +371,8 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
for
idx
in
range
(
len
(
past_key_values
)):
new_past
.
append
(
(
past_key_values
[
idx
][
0
][:,
:,
:
max
imum
_length
],
past_key_values
[
idx
][
1
][:,
:
max
imum
_length
,
:],
past_key_values
[
idx
][
0
][:,
:,
:
max_length
],
past_key_values
[
idx
][
1
][:,
:
max_length
,
:],
)
)
past_key_values
=
tuple
(
new_past
)
...
...
@@ -395,19 +382,19 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
):
if
model
.
config
.
multi_query
:
for
idx
in
range
(
len
(
past_key_values
)):
past_key_values
[
idx
]
=
past_key_values
[
idx
][:,
:
max
imum
_length
,
:]
past_key_values
[
idx
]
=
past_key_values
[
idx
][:,
:
max_length
,
:]
else
:
for
idx
in
range
(
len
(
past_key_values
)):
past_key_values
[
idx
]
=
past_key_values
[
idx
][:,
:,
:
max
imum
_length
,
:]
past_key_values
[
idx
]
=
past_key_values
[
idx
][:,
:,
:
max_length
,
:]
elif
isinstance
(
past_key_values
,
DynamicCache
):
past_key_values
.
crop
(
max
imum
_length
)
past_key_values
.
crop
(
max_length
)
elif
past_key_values
is
not
None
:
for
idx
in
range
(
len
(
past_key_values
)):
new_past
.
append
(
(
past_key_values
[
idx
][
0
][:,
:,
:
max
imum
_length
,
:],
past_key_values
[
idx
][
1
][:,
:,
:
max
imum
_length
,
:],
past_key_values
[
idx
][
0
][:,
:,
:
max_length
,
:],
past_key_values
[
idx
][
1
][:,
:,
:
max_length
,
:],
)
)
past_key_values
=
tuple
(
new_past
)
...
...
src/transformers/generation/utils.py
View file @
a3fb96a4
...
...
@@ -3697,11 +3697,10 @@ class GenerationMixin:
model_kwargs
=
self
.
_get_initial_cache_position
(
input_ids
,
model_kwargs
)
# This is needed if return_dict_in_generate is True
start_from_empty_dynamic_cache
=
False
if
isinstance
(
model_kwargs
.
get
(
"past_key_values"
,
None
),
DynamicCache
):
if
len
(
model_kwargs
[
"past_key_values"
])
==
0
:
start_from_empty_dynamic_cache
=
True
else
:
start_from_empty_dynamic_cache
=
False
this_peer_finished
=
False
while
self
.
_has_unfinished_sequences
(
this_peer_finished
,
synced_gpus
,
device
=
input_ids
.
device
):
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.