Unverified Commit 20509ab0 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: unpack_inputs decorator independent from main_input_name (#18110)

parent fcefa200
...@@ -404,9 +404,7 @@ def unpack_inputs(func): ...@@ -404,9 +404,7 @@ def unpack_inputs(func):
fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args))) fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
# process the inputs and call the wrapped function # process the inputs and call the wrapped function
main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1]) unpacked_inputs = input_processing(func, self.config, **fn_args_and_kwargs)
main_input = fn_args_and_kwargs.pop(main_input_name, None)
unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
return func(self, **unpacked_inputs) return func(self, **unpacked_inputs)
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
...@@ -417,7 +415,7 @@ def unpack_inputs(func): ...@@ -417,7 +415,7 @@ def unpack_inputs(func):
return run_call_with_unpacked_inputs return run_call_with_unpacked_inputs
def input_processing(func, config, input_ids, **kwargs): def input_processing(func, config, **kwargs):
""" """
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32', has to be named accordingly to the parameters name, i.e. `input_ids = tf.keras.Input(shape=(128,), dtype='int32',
...@@ -438,6 +436,8 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -438,6 +436,8 @@ def input_processing(func, config, input_ids, **kwargs):
has_kwargs = bool(signature.pop("kwargs", None)) has_kwargs = bool(signature.pop("kwargs", None))
signature.pop("self", None) signature.pop("self", None)
parameter_names = list(signature.keys()) parameter_names = list(signature.keys())
main_input_name = parameter_names[0]
main_input = kwargs.pop(main_input_name, None)
output = {} output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor) allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
...@@ -483,8 +483,8 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -483,8 +483,8 @@ def input_processing(func, config, input_ids, **kwargs):
else: else:
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
if isinstance(input_ids, (tuple, list)): if isinstance(main_input, (tuple, list)):
for i, input in enumerate(input_ids): for i, input in enumerate(main_input):
# EagerTensors don't allow to use the .name property so we check for a real Tensor # EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor: if type(input) == tf.Tensor:
# Tensor names have always the pattern `name:id` then we check only the # Tensor names have always the pattern `name:id` then we check only the
...@@ -502,25 +502,25 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -502,25 +502,25 @@ def input_processing(func, config, input_ids, **kwargs):
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for" f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
f" {parameter_names[i]}." f" {parameter_names[i]}."
) )
elif isinstance(input_ids, Mapping): elif isinstance(main_input, Mapping):
if "inputs" in input_ids: if "inputs" in main_input:
warnings.warn( warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`" "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
" instead.", " instead.",
FutureWarning, FutureWarning,
) )
output["input_ids"] = input_ids.pop("inputs") output["input_ids"] = main_input.pop("inputs")
if "decoder_cached_states" in input_ids: if "decoder_cached_states" in main_input:
warnings.warn( warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
" `past_key_values` instead.", " `past_key_values` instead.",
FutureWarning, FutureWarning,
) )
output["past_key_values"] = input_ids.pop("decoder_cached_states") output["past_key_values"] = main_input.pop("decoder_cached_states")
for k, v in dict(input_ids).items(): for k, v in dict(main_input).items():
if isinstance(v, allowed_types) or v is None: if isinstance(v, allowed_types) or v is None:
output[k] = v output[k] = v
elif k not in parameter_names and "args" not in parameter_names: elif k not in parameter_names and "args" not in parameter_names:
...@@ -531,12 +531,12 @@ def input_processing(func, config, input_ids, **kwargs): ...@@ -531,12 +531,12 @@ def input_processing(func, config, input_ids, **kwargs):
else: else:
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
else: else:
if isinstance(input_ids, (tf.Tensor, KerasTensor)) or input_ids is None: if isinstance(main_input, (tf.Tensor, KerasTensor)) or main_input is None:
output[parameter_names[0]] = input_ids output[main_input_name] = main_input
else: else:
raise ValueError( raise ValueError(
f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for" f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for"
f" {parameter_names[0]}." f" {main_input_name}."
) )
# Populates any unspecified argument with their default value, according to the signature. # Populates any unspecified argument with their default value, according to the signature.
......
...@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -1881,6 +1881,7 @@ class UtilsFunctionsTest(unittest.TestCase):
def __init__(self): def __init__(self):
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False} config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
self.config = PretrainedConfig(**config_kwargs) self.config = PretrainedConfig(**config_kwargs)
self.main_input_name = "input_ids"
@unpack_inputs @unpack_inputs
def call( def call(
...@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -1888,9 +1889,14 @@ class UtilsFunctionsTest(unittest.TestCase):
): ):
return input_ids, past, output_attentions, output_hidden_states, return_dict return input_ids, past, output_attentions, output_hidden_states, return_dict
@unpack_inputs
def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
return pixel_values, output_attentions, output_hidden_states, return_dict
dummy_model = DummyModel() dummy_model = DummyModel()
input_ids = tf.constant([0, 1, 2, 3]) input_ids = tf.constant([0, 1, 2, 3])
past = tf.constant([4, 5, 6, 7]) past = tf.constant([4, 5, 6, 7])
pixel_values = tf.constant([8, 9, 10, 11])
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config. # test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output = dummy_model.call(input_ids=input_ids, past=past) output = dummy_model.call(input_ids=input_ids, past=past)
...@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -1937,6 +1943,14 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertFalse(output[3]) self.assertFalse(output[3])
self.assertFalse(output[4]) self.assertFalse(output[4])
# test case 7: the decorator is independent from `main_input_name` -- it treats the first argument of the
# decorated function as its main input.
output = dummy_model.foo(pixel_values=pixel_values)
tf.debugging.assert_equal(output[0], pixel_values)
self.assertFalse(output[1])
self.assertFalse(output[2])
self.assertFalse(output[3])
# Tests whether the stable softmax is stable on CPU, with and without XLA # Tests whether the stable softmax is stable on CPU, with and without XLA
def test_xla_stable_softmax(self): def test_xla_stable_softmax(self):
large_penalty = -1e9 large_penalty = -1e9
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment