"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a4e530e3c89fcd1cba869587d6d04929bc28bbbe"
Unverified Commit b7018abf authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: Unpack model inputs through a decorator (#15907)

* MVP

* apply decorator to TFBertModel

* finish updating bert

* update rembert (copy-linked to bert)

* update roberta (copy-linked to bert); Fix args

* Now working for non-text modalities
parent 19597998
...@@ -344,6 +344,46 @@ def booleans_processing(config, **kwargs): ...@@ -344,6 +344,46 @@ def booleans_processing(config, **kwargs):
return final_booleans return final_booleans
def unpack_inputs(func):
"""
Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables
downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input
(common case in Keras).
Args:
func (`callable`):
The callable function of the TensorFlow model.
Returns:
A callable that wraps the original `func` with the behavior described above.
"""
original_signature = inspect.signature(func)
@functools.wraps(func)
def run_call_with_unpacked_inputs(self, *args, **kwargs):
# isolates the actual `**kwargs` for the decorated function
kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)}
fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call}
fn_args_and_kwargs.update({"kwargs_call": kwargs_call})
# move any arg into kwargs, if they exist
fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
# process the inputs and call the wrapped function
main_input_name = getattr(self, "main_input_name", func.__code__.co_varnames[1])
main_input = fn_args_and_kwargs.pop(main_input_name)
unpacked_inputs = input_processing(func, self.config, main_input, **fn_args_and_kwargs)
return func(self, **unpacked_inputs)
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
# function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below
# Keras would attempt to check the first argument against the literal signature of the wrapper.
run_call_with_unpacked_inputs.__signature__ = original_signature
return run_call_with_unpacked_inputs
def input_processing(func, config, input_ids, **kwargs): def input_processing(func, config, input_ids, **kwargs):
""" """
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment