Unverified Commit 9e8c37dc authored by Joao Gante, committed by GitHub

TF - Fix interchangeable past/past_key_values and revert output variable name in GPT2 (#16332)

* revert tf gpt2

* add test for unpack_inputs and fix test case

* add changes to vision encoder decoder
parent 12428f0e
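
This commit makes `past` and `past_key_values` interchangeable again when calling TF models: the input-processing layer now remaps whichever alias the wrapped signature does not declare, and TF GPT-2's `prepare_inputs_for_generation` goes back to emitting the signature name `past`. A minimal sketch of the calling pattern this enables (the checkpoint name and surrounding code are illustrative, not part of this commit):

```python
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello there", return_tensors="tf")
outputs = model(**inputs, use_cache=True)  # the output keeps the name `past_key_values`

next_token = tf.constant([[tokenizer.eos_token_id]])
# `TFGPT2LMHeadModel.call` declares `past` in its signature, but after this fix
# either spelling reaches the same argument:
out_a = model(next_token, past=outputs.past_key_values)
out_b = model(next_token, past_key_values=outputs.past_key_values)
```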
@@ -372,7 +372,7 @@ def unpack_inputs(func):
         # process the inputs and call the wrapped function
         main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
-        main_input = fn_args_and_kwargs.pop(main_input_name)
+        main_input = fn_args_and_kwargs.pop(main_input_name, None)
         unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
         return func(self, **unpacked_inputs)
@@ -423,13 +423,13 @@ def input_processing(func, config, input_ids, **kwargs):
             )
             output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
 
-    if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
+    if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
         warnings.warn(
             "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
             FutureWarning,
         )
         kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
-    elif "past_key_values" in kwargs["kwargs_call"] and "past" in kwargs:
+    elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
         kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
 
     if len(kwargs["kwargs_call"]) > 0:
@@ -497,6 +497,7 @@ def input_processing(func, config, input_ids, **kwargs):
                 f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
             )
 
+    # Populates any unspecified argument with their default value, according to the signature.
     for name in parameter_names:
         if name not in list(output.keys()) and name != "args":
             output[name] = kwargs.pop(name, signature[name].default)
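
The one-line fixes above address a subtle bug: the alias check previously tested membership in `kwargs` (the arguments passed on this particular call) rather than in `parameter_names` (the wrapped function's signature), so the remap could fail to fire exactly when it was needed, i.e. when the caller used only the alias. A simplified, self-contained sketch of the corrected logic (the harness below is hypothetical; only the membership test mirrors the diff):

```python
import inspect
import warnings

def remap_past_alias(func, kwargs_call, kwargs):
    """Move a deprecated alias onto the name the signature actually declares."""
    parameter_names = list(inspect.signature(func).parameters.keys())
    if "past" in kwargs_call and "past_key_values" in parameter_names:
        warnings.warn(
            "The `past` argument is deprecated and will be removed in a future version,"
            " use `past_key_values` instead.",
            FutureWarning,
        )
        kwargs["past_key_values"] = kwargs_call.pop("past")
    elif "past_key_values" in kwargs_call and "past" in parameter_names:
        kwargs["past"] = kwargs_call.pop("past_key_values")
    return kwargs

def call(input_ids=None, past=None):  # the signature declares `past`
    return past

# The caller used the canonical name, but the signature only has the alias:
kwargs = remap_past_alias(call, {"past_key_values": [1, 2]}, {})
assert "past" in kwargs  # remapped onto the signature name
```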
...
@@ -694,6 +694,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     ):
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        past_key_values = decoder_inputs.get("past_key_values")
+        if past_key_values is None:
+            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
         input_dict = {
             "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
@@ -701,7 +704,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "decoder_input_ids": decoder_inputs["input_ids"],
             # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
             "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
-            "past_key_values": decoder_inputs["past_key_values"],
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
         return input_dict
...
@@ -878,7 +878,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
             "input_ids": inputs,
             "attention_mask": attention_mask,
             "position_ids": position_ids,
-            "past_key_values": past,
+            "past": past,
             "use_cache": use_cache,
         }
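
This hunk reverts the dict key to the signature name `past`: the dict returned by `prepare_inputs_for_generation` is expanded directly into the model call during generation, so its keys must match `call`'s parameters, while the `.get(...)` fallback added to the encoder-decoder classes above accepts decoders emitting either key. Continuing the sketch from the top of this page (still illustrative):

```python
# After the revert, TF GPT-2 emits the signature name `past` again, so the
# returned dict can be splatted straight into the model call:
model_inputs = model.prepare_inputs_for_generation(next_token, past=outputs.past_key_values)
assert "past" in model_inputs and "past_key_values" not in model_inputs
out_c = model(**model_inputs)
```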
...
@@ -725,6 +725,9 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     ):
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        past_key_values = decoder_inputs.get("past_key_values")
+        if past_key_values is None:
+            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
         input_dict = {
             "pixel_values": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
@@ -732,7 +735,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "decoder_input_ids": decoder_inputs["input_ids"],
             # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
             "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
-            "past_key_values": decoder_inputs["past_key_values"],
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
         }
         return input_dict
...
@@ -27,6 +27,7 @@ from typing import List, Tuple
 from huggingface_hub import delete_repo, login
 from requests.exceptions import HTTPError
 from transformers import is_tf_available
+from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import get_values
 from transformers.testing_utils import tooslow  # noqa: F401
 from transformers.testing_utils import (
@@ -80,6 +81,7 @@ if is_tf_available():
         TFSampleDecoderOnlyOutput,
         TFSampleEncoderDecoderOutput,
     )
+    from transformers.modeling_tf_utils import unpack_inputs
 
 if _tf_gpu_memory_limit is not None:
     gpus = tf.config.list_physical_devices("GPU")
@@ -1553,6 +1555,68 @@ class UtilsFunctionsTest(unittest.TestCase):
         tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
         tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
 
+    # tests whether the unpack_inputs function behaves as expected
+    def test_unpack_inputs(self):
+        class DummyModel:
+            def __init__(self):
+                config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
+                self.config = PretrainedConfig(**config_kwargs)
+
+            @unpack_inputs
+            def call(
+                self, input_ids=None, past=None, output_attentions=None, output_hidden_states=None, return_dict=None
+            ):
+                return input_ids, past, output_attentions, output_hidden_states, return_dict
+
+        dummy_model = DummyModel()
+        input_ids = tf.constant([0, 1, 2, 3])
+        past = tf.constant([4, 5, 6, 7])
+
+        # test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
+        output = dummy_model.call(input_ids=input_ids, past=past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 2: Same as above, but with positional arguments.
+        output = dummy_model.call(input_ids, past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 3: We can also pack everything in the first input.
+        output = dummy_model.call(input_ids={"input_ids": input_ids, "past": past})
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
+        # test case 4: Explicit boolean arguments should override the config.
+        output = dummy_model.call(input_ids=input_ids, past=past, output_attentions=False, return_dict=True)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertTrue(output[4])
+
+        # test case 5: Unexpected arguments should raise an exception.
+        with self.assertRaises(ValueError):
+            output = dummy_model.call(input_ids=input_ids, past=past, foo="bar")
+
+        # test case 6: Despite the above, `past_key_values` should be interchangeable with `past`
+        # (the decorator moves it to `past`, or vice-versa, depending on the signature).
+        output = dummy_model.call(input_ids=input_ids, past_key_values=past)
+        tf.debugging.assert_equal(output[0], input_ids)
+        tf.debugging.assert_equal(output[1], past)
+        self.assertFalse(output[2])
+        self.assertFalse(output[3])
+        self.assertFalse(output[4])
+
 
 @require_tf
 @is_staging_test
...
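
The new `test_unpack_inputs` case lives in `UtilsFunctionsTest`; assuming the hunks above belong to `tests/test_modeling_tf_common.py` (the file path is not shown in this view), it can be run in isolation with `python -m pytest tests/test_modeling_tf_common.py -k test_unpack_inputs`.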