Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a6c82d45
Unverified
Commit
a6c82d45
authored
Nov 02, 2023
by
Joao Gante
Committed by
GitHub
Nov 02, 2023
Browse files
Generate: return `past_key_values` (#25086)
parent
441c3e0d
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
233 additions
and
5 deletions
+233
-5
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+112
-5
tests/generation/test_utils.py
tests/generation/test_utils.py
+121
-0
No files found.
src/transformers/generation/utils.py
View file @
a6c82d45
This diff is collapsed.
Click to expand it.
tests/generation/test_utils.py
View file @
a6c82d45
...
@@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
...
@@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
outputs_from_embeds_wo_ids
[:,
1
:].
tolist
(),
outputs_from_embeds_wo_ids
[:,
1
:].
tolist
(),
)
)
def test_generate_continue_from_past_key_values(self):
    """Check that `generate` can continue from the `past_key_values` returned by a previous `generate` call.

    For each class in `self.all_generative_model_classes`, generating 4 new tokens in a single call (with
    `do_sample=False`) is compared against generating 3 new tokens and then 1 more token from the returned cache.
    The resulting sequences and the final past key values must match.
    """
    # won't fix: old models with unique inputs/caches/others
    wont_fix_models = ("imagegpt",)
    # may fix in the future: needs modeling or test input preparation fixes for compatibility
    may_fix_models = ("umt5",)
    excluded_models = wont_fix_models + may_fix_models

    for model_class in self.all_generative_model_classes:
        # Skip only the current class: a `return` here would silently drop every remaining class as well
        model_class_name = model_class.__name__.lower()
        if any(model_name in model_class_name for model_name in excluded_models):
            continue

        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        # If it doesn't support cache, there is nothing to check for this class
        if not hasattr(config, "use_cache"):
            continue

        # Let's make it always:
        # 1. use cache (for obvious reasons)
        # 2. generate to max length (which can be achieved by setting the eos token to an invalid value).
        #    Otherwise the test would be flaky (e.g. EOS is generated on iteration 1 on both generations, but
        #    the continuation would force it to generate beyond an EOS token)
        # 3. ignore `token_type_ids` for simplicity
        # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and
        #    is active by default on some models
        config.use_cache = True
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]

        model = model_class(config).to(torch_device)
        model.eval()
        model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
        model.generation_config.forced_eos_token_id = None

        # If "past_key_values" is not returned, skip this class (e.g. RWKV uses a different cache name and format)
        outputs = model(**inputs)
        if "past_key_values" not in outputs:
            continue

        # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
        outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)

        # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
        # inputs may need to be tweaked across `generate` calls (like the attention mask).
        outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)

        # Continue from the tokens generated above, preparing the inputs accordingly
        inputs["past_key_values"] = outputs_cached.past_key_values
        new_attention_len = outputs_cached.sequences.shape[-1]
        if config.is_encoder_decoder:
            ids_key, mask_key = "decoder_input_ids", "decoder_attention_mask"
        else:
            ids_key, mask_key = "input_ids", "attention_mask"
        inputs[ids_key] = outputs_cached.sequences
        if mask_key in inputs:
            # Right-pad the mask with ones so that it also covers the tokens generated by the first call
            inputs[mask_key] = torch.nn.functional.pad(
                inputs[mask_key],
                (0, new_attention_len - inputs[mask_key].shape[1]),
                mode="constant",
                value=1,
            )
        outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)

        # The two sets of generated text and past kv should be equal to each other
        self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
        # Compare the cache layouts first, so that `zip` below cannot hide a length mismatch
        self.assertEqual(len(outputs.past_key_values), len(outputs_cached.past_key_values))
        for layer_past, layer_past_cached in zip(outputs.past_key_values, outputs_cached.past_key_values):
            self.assertEqual(len(layer_past), len(layer_past_cached))
            for past_tensor, past_tensor_cached in zip(layer_past, layer_past_cached):
                self.assertTrue(torch.allclose(past_tensor, past_tensor_cached))
def
_check_outputs
(
self
,
output
,
input_ids
,
config
,
use_cache
=
False
,
num_return_sequences
=
1
):
def
_check_outputs
(
self
,
output
,
input_ids
,
config
,
use_cache
=
False
,
num_return_sequences
=
1
):
batch_size
,
seq_length
=
input_ids
.
shape
batch_size
,
seq_length
=
input_ids
.
shape
num_sequences_in_output
=
batch_size
*
num_return_sequences
num_sequences_in_output
=
batch_size
*
num_return_sequences
...
@@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
...
@@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
use_cache
=
use_cache
,
use_cache
=
use_cache
,
)
)
# Past Key Value States -- three notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
models_without_standard_cache
=
(
"bloom"
,
"ctrl"
,
"fsmt"
,
"gptbigcode"
,
"mega"
,
"reformer"
)
has_standard_cache
=
not
any
(
model_name
in
config
.
__class__
.
__name__
.
lower
()
for
model_name
in
models_without_standard_cache
)
if
use_cache
and
has_standard_cache
:
past_key_values
=
output
.
past_key_values
past_sequence_length
=
output
.
sequences
.
shape
[
-
1
]
-
1
self
.
_check_past_key_values_for_generate
(
num_sequences_in_output
,
past_key_values
,
seq_length
=
past_sequence_length
,
config
=
config
,
)
def
_check_scores
(
self
,
batch_size
,
scores
,
length
,
config
):
def
_check_scores
(
self
,
batch_size
,
scores
,
length
,
config
):
expected_shape
=
(
batch_size
,
config
.
vocab_size
)
expected_shape
=
(
batch_size
,
config
.
vocab_size
)
self
.
assertIsInstance
(
scores
,
tuple
)
self
.
assertIsInstance
(
scores
,
tuple
)
...
@@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
...
@@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
[
encoder_expected_shape
]
*
len
(
hidden_states
),
[
encoder_expected_shape
]
*
len
(
hidden_states
),
)
)
def
_check_past_key_values_for_generate
(
self
,
batch_size
,
past_key_values
,
seq_length
,
config
,
num_beam_groups
=
1
):
self
.
assertIsInstance
(
past_key_values
,
tuple
)
self
.
assertListEqual
(
[
isinstance
(
iter_past_key_values
,
tuple
)
for
iter_past_key_values
in
past_key_values
],
[
True
]
*
len
(
past_key_values
),
)
# (batch, head, seq_length, head_features)
expected_shape
=
(
batch_size
*
num_beam_groups
,
config
.
num_key_value_heads
if
hasattr
(
config
,
"num_key_value_heads"
)
else
config
.
num_attention_heads
,
seq_length
,
config
.
hidden_size
//
config
.
num_attention_heads
,
)
# check shape key, value
self
.
assertListEqual
(
[
layer_past_key_values
[
0
].
shape
for
layer_past_key_values
in
past_key_values
],
[
expected_shape
]
*
len
(
past_key_values
),
)
self
.
assertListEqual
(
[
layer_past_key_values
[
1
].
shape
for
layer_past_key_values
in
past_key_values
],
[
expected_shape
]
*
len
(
past_key_values
),
)
def
_check_sequence_inside_sequence
(
self
,
tensor_1
,
tensor_2
):
def
_check_sequence_inside_sequence
(
self
,
tensor_1
,
tensor_2
):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# set to same device. we don't care what device.
# set to same device. we don't care what device.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment