Unverified Commit fcefa200 authored by Joao Gante, committed by GitHub

TF: remove graph mode distinction when processing boolean options (#18102)

parent bc34c211
@@ -346,8 +346,7 @@ class TFNextSentencePredictionLoss:
 def booleans_processing(config, **kwargs):
     """
-    Process the input booleans of each model in order to be sure they are compliant with the execution mode (eager or
-    graph)
+    Process the input booleans of each model.
 
     Args:
         config ([`PretrainedConfig`]):
@@ -360,42 +359,21 @@ def booleans_processing(config, **kwargs):
     """
     final_booleans = {}
 
-    if tf.executing_eagerly():
-        # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has
-        # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`)
-        if "output_attentions" in kwargs:
-            final_booleans["output_attentions"] = (
-                kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
-            )
-        final_booleans["output_hidden_states"] = (
-            kwargs["output_hidden_states"]
-            if kwargs["output_hidden_states"] is not None
-            else config.output_hidden_states
-        )
-        final_booleans["return_dict"] = (
-            kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict
-        )
+    # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has
+    # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`)
+    if "output_attentions" in kwargs:
+        final_booleans["output_attentions"] = (
+            kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
+        )
+    final_booleans["output_hidden_states"] = (
+        kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states
+    )
+    final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict
 
-        if "use_cache" in kwargs:
-            final_booleans["use_cache"] = (
-                kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
-            )
-    else:
-        # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has
-        # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`)
-        if "output_attentions" in kwargs:
-            final_booleans["output_attentions"] = config.output_attentions
-
-        final_booleans["output_hidden_states"] = config.output_hidden_states
-
-        if kwargs.get("return_dict", None) not in (None, True):
-            tf_logger.warning(
-                "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`."
-            )
-        final_booleans["return_dict"] = True
-
-        if "use_cache" in kwargs:
-            final_booleans["use_cache"] = getattr(config, "use_cache", None)
+    if "use_cache" in kwargs:
+        final_booleans["use_cache"] = (
+            kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
+        )
 
     return final_booleans
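For reference, a minimal sketch of the resolved behaviour after this change, assuming the updated `booleans_processing` is importable from `transformers.modeling_tf_utils` at this version and using a made-up config object for illustration; the kwargs now fall back to the config values identically in eager and graph mode:

```python
# Minimal sketch (illustrative values, not from the diff): unset (None) kwargs fall
# back to the corresponding config attributes; explicitly set kwargs are kept as-is.
from types import SimpleNamespace

from transformers.modeling_tf_utils import booleans_processing  # assumed location at this version

config = SimpleNamespace(output_attentions=False, output_hidden_states=True, return_dict=True, use_cache=True)

resolved = booleans_processing(
    config,
    output_attentions=None,      # unset -> falls back to config.output_attentions (False)
    output_hidden_states=False,  # explicitly set -> kept as False
    return_dict=None,            # unset -> falls back to config.return_dict (True)
    use_cache=None,              # unset -> falls back to getattr(config, "use_cache", None) (True)
)
# resolved == {"output_attentions": False, "output_hidden_states": False,
#              "return_dict": True, "use_cache": True}
```

The only behavioural difference from the removed branch is in graph mode, where `return_dict` is no longer forced to `True` and the other flags no longer ignore the call-time kwargs.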