Unverified commit dcd3046f, authored by Julien Plu and committed by GitHub

Better booleans handling in the TF models (#8777)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add deprecated arguments

* Add input processing to TF XLM

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input processing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add boolean processing for the inputs

* Apply style

* Missing optional

* Fix missing some input proc

* Update the template

* Fix missing inputs

* Missing input

* Fix args parameter

* Trigger CI

* Trigger CI

* Trigger CI

* Address Patrick's and Sylvain's comments

* Replace warn with warning

* Trigger CI

* Fix XLNET

* Fix detection
parent 4c3d98dd
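Every hunk below follows the same pattern: the per-call fallback blocks for output_attentions, output_hidden_states, and return_dict are deleted from each call() method, and input_processing receives the model config so it can resolve those booleans in one place. As a rough sketch of the idea only (this is a simplification, not the actual input_processing implementation in transformers.modeling_tf_utils, and resolve_boolean_defaults is a hypothetical helper name):

# Minimal sketch of the boolean handling this PR centralizes, assuming a
# config object that exposes output_attentions / output_hidden_states /
# return_dict. `resolve_boolean_defaults` is a made-up illustration, not
# the real input_processing function.

BOOLEAN_KWARGS = ("output_attentions", "output_hidden_states", "return_dict")

def resolve_boolean_defaults(inputs: dict, config) -> dict:
    """Replace any boolean kwarg left as None with its config default."""
    for name in BOOLEAN_KWARGS:
        if inputs.get(name) is None:
            inputs[name] = getattr(config, name)
    return inputs

# Before: every call() repeated the fallback inline, e.g.
#     return_dict = return_dict if return_dict is not None else self.return_dict
# After: input_processing(func=self.call, config=self.config, ...) returns a
# dict whose booleans are already resolved, so call() reads
# inputs["return_dict"] directly.

Centralizing the fallback removes a near-identical three-line block from every head class and makes it impossible for one of them to forget a boolean.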
@@ -474,6 +474,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
+        self.config = config
         self.num_hidden_layers = config.num_hidden_layers
         self.initializer_range = config.initializer_range
         self.output_attentions = config.output_attentions
@@ -513,6 +514,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
     ):
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -525,14 +527,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
             training=training,
             kwargs_call=kwargs,
         )
-        output_attentions = (
-            inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
-        )
-        output_hidden_states = (
-            inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
-        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
 
         if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif inputs["input_ids"] is not None:
@@ -585,9 +580,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
             inputs["head_mask"],
-            output_attentions,
-            output_hidden_states,
-            return_dict,
+            inputs["output_attentions"],
+            inputs["output_hidden_states"],
+            inputs["return_dict"],
             training=inputs["training"],
         )
@@ -740,6 +735,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
     ):
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -817,6 +813,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
         """
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -830,7 +827,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
             training=training,
             kwargs_call=kwargs,
         )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
         outputs = self.{{cookiecutter.lowercase_modelname}}(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
@@ -840,7 +836,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
             inputs_embeds=inputs["inputs_embeds"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )
@@ -848,7 +844,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
         prediction_scores = self.mlm(sequence_output, training=inputs["training"])
         loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
 
-        if not return_dict:
+        if not inputs["return_dict"]:
             output = (prediction_scores,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -930,6 +926,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
         """
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -943,7 +940,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
             training=training,
             kwargs_call=kwargs,
         )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
         outputs = self.{{cookiecutter.lowercase_modelname}}(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
@@ -953,13 +949,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
             inputs_embeds=inputs["inputs_embeds"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )
 
         logits = self.classifier(outputs[0], training=inputs["training"])
         loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
 
-        if not return_dict:
+        if not inputs["return_dict"]:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1029,6 +1025,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
         """
        inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1042,8 +1039,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
             training=training,
             kwargs_call=kwargs,
         )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
 
         if inputs["input_ids"] is not None:
             num_choices = shape_list(inputs["input_ids"])[1]
             seq_length = shape_list(inputs["input_ids"])[2]
@@ -1075,7 +1071,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
             flat_inputs_embeds,
             inputs["output_attentions"],
             inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )
         logits = self.sequence_summary(outputs[0], training=inputs["training"])
@@ -1083,7 +1079,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
         reshaped_logits = tf.reshape(logits, (-1, num_choices))
         loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
 
-        if not return_dict:
+        if not inputs["return_dict"]:
             output = (reshaped_logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1141,6 +1137,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
         """
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1154,7 +1151,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
             training=training,
             kwargs_call=kwargs,
         )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
         outputs = self.{{cookiecutter.uppercase_modelname}}(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
@@ -1164,7 +1160,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
             inputs_embeds=inputs["inputs_embeds"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )
         sequence_output = outputs[0]
@@ -1172,7 +1168,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
         logits = self.classifier(sequence_output)
         loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
 
-        if not return_dict:
+        if not inputs["return_dict"]:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1235,6 +1231,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
         """
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1249,7 +1246,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
             training=training,
             kwargs_call=kwargs,
         )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
         outputs = self.{{cookiecutter.uppercase_modelname}}(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
@@ -1259,7 +1255,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
             inputs_embeds=inputs["inputs_embeds"],
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
             training=inputs["training"],
         )
         sequence_output = outputs[0]
@@ -1274,7 +1270,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
             labels["end_position"] = inputs["end_positions"]
             loss = self.compute_loss(labels, (start_logits, end_logits))
 
-        if not return_dict:
+        if not inputs["return_dict"]:
             output = (start_logits, end_logits) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
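From the caller's perspective the behavior is unchanged: any boolean left as None still falls back to the config value; only the place where the fallback happens has moved. A small usage sketch (assumes a v4-era transformers install, where return_dict defaults to True in the config):

import tensorflow as tf
from transformers import BertConfig, TFBertModel

# Set a config-level default; it now flows through input_processing.
config = BertConfig(output_hidden_states=True)
model = TFBertModel(config)  # randomly initialized weights, enough for the demo

input_ids = tf.constant([[101, 7592, 2088, 102]])

# output_hidden_states is not passed explicitly, so the config value (True)
# is picked up during input processing.
outputs = model(input_ids)
print(len(outputs.hidden_states))  # embeddings + one tensor per layer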