chenpangpang / transformers · Commits

Unverified commit a6c82d45, authored Nov 02, 2023 by Joao Gante, committed via GitHub on Nov 02, 2023

Generate: return `past_key_values` (#25086)

parent 441c3e0d
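This commit makes `generate` return the cache (`past_key_values`) in its output object, so a follow-up `generate` call can continue from it instead of recomputing the prefix. The snippet below is a minimal usage sketch of that workflow, assuming the `gpt2` checkpoint and a single unpadded prompt; it mirrors the pattern exercised by the new test in tests/generation/test_utils.py and is not code taken from this diff.

```python
# Sketch only: "gpt2" and the prompt are illustrative choices, not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# First call: generate a few tokens and keep the returned cache.
first = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)

# Second call: continue from the cache. The full sequence so far becomes the new input_ids,
# and the attention mask is extended to cover the freshly generated tokens.
continued = model.generate(
    input_ids=first.sequences,
    attention_mask=torch.ones_like(first.sequences),
    past_key_values=first.past_key_values,
    do_sample=False,
    max_new_tokens=1,
    return_dict_in_generate=True,
)
print(tokenizer.decode(continued.sequences[0]))
```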
Showing 2 changed files with 233 additions and 5 deletions (+233 −5)
src/transformers/generation/utils.py   +112 −5
tests/generation/test_utils.py         +121 −0
src/transformers/generation/utils.py (view file @ a6c82d45)

This diff is collapsed in this view.
tests/generation/test_utils.py (view file @ a6c82d45)
@@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
            outputs_from_embeds_wo_ids[:, 1:].tolist(),
        )

    def test_generate_continue_from_past_key_values(self):
        # Tests that we can continue generating from past key values, returned from a previous `generate` call
        for model_class in self.all_generative_model_classes:
            # won't fix: old models with unique inputs/caches/others
            if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
                return
            # may fix in the future: needs modeling or test input preparation fixes for compatibility
            if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
                return

            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

            # If it doesn't support cache, pass the test
            if not hasattr(config, "use_cache"):
                return

            # Let's make it always:
            # 1. use cache (for obvious reasons)
            # 2. generate to max length (which can be achieved by setting the eos token to an invalid value);
            #    otherwise an early EOS would make the test flaky (e.g. EOS is generated on iteration 1 on both
            #    generations, but the continuation would force it to generate beyond an EOS token)
            # 3. ignore `token_type_ids` for simplicity
            # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
            #    active by default on some models
            config.use_cache = True
            if "token_type_ids" in inputs:
                del inputs["token_type_ids"]

            model = model_class(config).to(torch_device)
            model.eval()
            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
            model.generation_config.forced_eos_token_id = None

            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
            outputs = model(**inputs)
            if "past_key_values" not in outputs:
                return

            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)

            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
            # inputs may need to be tweaked across `generate` calls (like the attention mask).
            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)

            # Continue from the tokens generated above, preparing the inputs accordingly
            inputs["past_key_values"] = outputs_cached.past_key_values
            new_attention_len = outputs_cached.sequences.shape[-1]
            if config.is_encoder_decoder:
                inputs["decoder_input_ids"] = outputs_cached.sequences
                if "decoder_attention_mask" in inputs:
                    inputs["decoder_attention_mask"] = torch.nn.functional.pad(
                        inputs["decoder_attention_mask"],
                        (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
                        mode="constant",
                        value=1,
                    )
            else:
                inputs["input_ids"] = outputs_cached.sequences
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = torch.nn.functional.pad(
                        inputs["attention_mask"],
                        (0, new_attention_len - inputs["attention_mask"].shape[1]),
                        mode="constant",
                        value=1,
                    )
            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)

            # The two sets of generated text and past kv should be equal to each other
            self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
            for layer_idx in range(len(outputs_cached.past_key_values)):
                for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
                    self.assertTrue(
                        torch.allclose(
                            outputs.past_key_values[layer_idx][kv_idx],
                            outputs_cached.past_key_values[layer_idx][kv_idx],
                        )
                    )

    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        batch_size, seq_length = input_ids.shape
        num_sequences_in_output = batch_size * num_return_sequences
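For decoder-only models, the input preparation that the test above performs between the two `generate` calls boils down to three steps: reuse the returned cache, feed the full sequence so far as `input_ids`, and pad the attention mask up to the new length. The helper below is a hypothetical distillation of that branch of the test; its name and signature are mine and are not part of this commit.

```python
import torch

def prepare_continuation_inputs(inputs: dict, previous) -> dict:
    """Hypothetical helper (not from this commit): reuse a previous `generate` output's cache
    for a follow-up call, mirroring the decoder-only branch of the test above."""
    new_inputs = dict(inputs)
    new_inputs["past_key_values"] = previous.past_key_values
    new_inputs["input_ids"] = previous.sequences
    if "attention_mask" in new_inputs:
        # Extend the mask with ones so it covers the tokens generated in the previous call.
        pad_amount = previous.sequences.shape[-1] - new_inputs["attention_mask"].shape[1]
        new_inputs["attention_mask"] = torch.nn.functional.pad(
            new_inputs["attention_mask"], (0, pad_amount), mode="constant", value=1
        )
    return new_inputs
```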
@@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
                use_cache=use_cache,
            )

        # Past Key Value States -- a few notes here:
        # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
        # 2. Some old models still return `output.past_key_values` even without `use_cache=True`
        # 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
        models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
        has_standard_cache = not any(
            model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
        )
        if use_cache and has_standard_cache:
            past_key_values = output.past_key_values
            past_sequence_length = output.sequences.shape[-1] - 1
            self._check_past_key_values_for_generate(
                num_sequences_in_output,
                past_key_values,
                seq_length=past_sequence_length,
                config=config,
            )

    def _check_scores(self, batch_size, scores, length, config):
        expected_shape = (batch_size, config.vocab_size)
        self.assertIsInstance(scores, tuple)
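A small worked example of the "-1" in `past_sequence_length` above, with illustrative numbers of my own choosing: the cache only holds keys and values for tokens that were actually fed through a forward pass, and the last generated token never is.

```python
# Illustrative numbers only: a 5-token prompt plus 4 newly generated tokens.
prompt_length, max_new_tokens = 5, 4
sequence_length = prompt_length + max_new_tokens  # output.sequences.shape[-1] == 9
past_sequence_length = sequence_length - 1        # cached sequence dimension == 8
assert (sequence_length, past_sequence_length) == (9, 8)
```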
@@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
            [encoder_expected_shape] * len(hidden_states),
        )

    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
        self.assertIsInstance(past_key_values, tuple)
        self.assertListEqual(
            [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
            [True] * len(past_key_values),
        )

        # (batch, head, seq_length, head_features)
        expected_shape = (
            batch_size * num_beam_groups,
            config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
            seq_length,
            config.hidden_size // config.num_attention_heads,
        )
        # check shape key, value
        self.assertListEqual(
            [layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
            [expected_shape] * len(past_key_values),
        )
        self.assertListEqual(
            [layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
            [expected_shape] * len(past_key_values),
        )

    def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
        # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
        # set to same device. we don't care what device.
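To see the layout that `_check_past_key_values_for_generate` asserts on a real checkpoint, the sketch below (assuming `gpt2` is available; not part of this commit) prints the cache returned by `generate`: a tuple with one entry per layer, each a `(key, value)` pair of shape `(batch * num_beam_groups, num_heads, sequence_length - 1, head_dim)`.

```python
# Inspection sketch only; the checkpoint and prompt are my illustrative choices.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello world", return_tensors="pt")
out = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)

print(len(out.past_key_values))  # one entry per layer (12 for gpt2)
key, value = out.past_key_values[0]
print(key.shape, value.shape)    # e.g. (1, 12, 5, 64) here: (batch, heads, seq_len - 1, head_dim)
```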