Unverified commit dcd3046f authored by Julien Plu, committed by GitHub

Better booleans handling in the TF models (#8777)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add deprecated arguments

* Add input processing to TF XLM

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input processing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add boolean processing for the inputs

* Apply style

* Missing optional

* Fix missing some input proc

* Update the template

* Fix missing inputs

* Missing input

* Fix args parameter

* Trigger CI

* Trigger CI

* Trigger CI

* Address Patrick's and Sylvain's comments

* Replace warn by warning

* Trigger CI

* Fix XLNET

* Fix detection
parent 4c3d98dd
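
The boolean-resolution rule that the diff below centralizes can be summarized in a few lines of plain Python (a minimal sketch of the rule only, not the library code): in eager execution a call-time value overrides the config default, while in graph mode the config value always wins.

    def resolve_flag(call_time_value, config_value, eager):
        # Eager mode: an explicit per-call value takes precedence, None falls back to the config.
        if eager:
            return call_time_value if call_time_value is not None else config_value
        # Graph mode: per-call overrides are ignored, the config value is used.
        return config_value
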
@@ -247,7 +247,81 @@ class TFNextSentencePredictionLoss:
         return loss_fn(next_sentence_label, next_sentence_reduced_logits)


-def input_processing(func, input_ids, **kwargs):
+def booleans_processing(config, **kwargs):
+    """
+    Process the input booleans of each model in order to be sure they are compliant with the execution mode (eager or
+    graph)
+
+    Args:
+        config (:class:`~transformers.PretrainedConfig`):
+            The config of the running model.
+        **kwargs:
+            The boolean parameters
+
+    Returns:
+        A dictionary with the proper values for each boolean
+    """
+    final_booleans = {}
+
+    if tf.executing_eagerly():
+        final_booleans["output_attentions"] = (
+            kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
+        )
+        final_booleans["output_hidden_states"] = (
+            kwargs["output_hidden_states"]
+            if kwargs["output_hidden_states"] is not None
+            else config.output_hidden_states
+        )
+
+        if "return_dict" in kwargs:
+            final_booleans["return_dict"] = (
+                kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict
+            )
+
+        if "use_cache" in kwargs:
+            final_booleans["use_cache"] = kwargs["use_cache"] if kwargs["use_cache"] is not None else config.use_cache
+    else:
+        if (
+            kwargs["output_attentions"] is not None
+            or kwargs["output_hidden_states"] is not None
+            or ("use_cache" in kwargs and kwargs["use_cache"] is not None)
+        ):
+            logger.warning(
+                "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model."
+                "They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`)."
+            )
+
+        final_booleans["output_attentions"] = config.output_attentions
+        final_booleans["output_hidden_states"] = config.output_hidden_states
+
+        if "return_dict" in kwargs:
+            if kwargs["return_dict"] is not None:
+                logger.warning(
+                    "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`."
+                )
+            final_booleans["return_dict"] = True
+
+        if "use_cache" in kwargs:
+            final_booleans["use_cache"] = config.use_cache
+
+    return final_booleans
+
+
+def input_processing(func, config, input_ids, **kwargs):
+    """
+    Process the input of each TensorFlow model, including the booleans.
+
+    Args:
+        func (:obj:`callable`):
+            The callable function of the TensorFlow model.
+        config (:class:`~transformers.PretrainedConfig`):
+            The config of the running model.
+        **kwargs:
+            The inputs of the model.
+
+    Returns:
+        A dictionary with the processed inputs, with the boolean values resolved.
+    """
    signature = dict(inspect.signature(func).parameters)
    signature.pop("kwargs", None)
    parameter_names = list(signature.keys())
@@ -317,10 +391,15 @@ def input_processing(func, input_ids, **kwargs):
            output["past_key_values"] = input_ids.pop("decoder_cached_states")

        for k, v in dict(input_ids).items():
-            if not isinstance(v, allowed_types):
-                raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
-            else:
+            if isinstance(v, allowed_types) or v is None:
                 output[k] = v
+            elif k not in parameter_names and "args" not in parameter_names:
+                logger.warn(
+                    f"The parameter {k} does not belong to the parameter list {parameter_names} and will be ignored."
+                )
+                continue
+            else:
+                raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
    else:
        if isinstance(input_ids, tf.Tensor) or input_ids is None:
            output[parameter_names[0]] = input_ids
@@ -348,6 +427,19 @@ def input_processing(func, input_ids, **kwargs):
    if "kwargs" in output:
        del output["kwargs"]

+    boolean_dict = {
+        k: v
+        for k, v in output.items()
+        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
+    }
+
+    output.update(
+        booleans_processing(
+            config=config,
+            **boolean_dict,
+        )
+    )
+
    return output
...
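Assuming the function defined above is importable (it is added to transformers.modeling_tf_utils in this PR), a direct call in eager mode illustrates the fallback behaviour; the SimpleNamespace below merely stands in for a real PretrainedConfig and is only illustrative.

    import types

    from transformers.modeling_tf_utils import booleans_processing

    config = types.SimpleNamespace(
        output_attentions=False, output_hidden_states=False, return_dict=True, use_cache=True
    )
    flags = booleans_processing(
        config,
        output_attentions=True,     # explicit value: kept in eager mode
        output_hidden_states=None,  # None: falls back to config.output_hidden_states
        return_dict=None,           # None: falls back to config.return_dict
        use_cache=None,             # None: falls back to config.use_cache
    )
    # In eager execution: {'output_attentions': True, 'output_hidden_states': False,
    #                      'return_dict': True, 'use_cache': True}
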
@@ -484,9 +484,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = config.num_hidden_layers
-        self.output_attentions = config.output_attentions
-        self.output_hidden_states = config.output_hidden_states
-        self.return_dict = config.use_return_dict
+        self.config = config
        self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
        self.encoder = TFAlbertTransformer(config, name="encoder")
@@ -530,6 +528,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
    ):
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -543,14 +542,6 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
            kwargs_call=kwargs,
        )
-        output_attentions = (
-            inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
-        )
-        output_hidden_states = (
-            inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
-        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict

        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif inputs["input_ids"] is not None:
@@ -603,16 +594,16 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
            embedding_output,
            extended_attention_mask,
            inputs["head_mask"],
-            output_attentions,
-            output_hidden_states,
-            return_dict,
+            inputs["output_attentions"],
+            inputs["output_hidden_states"],
+            inputs["return_dict"],
            training=inputs["training"],
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output[:, 0])

-        if not return_dict:
+        if not inputs["return_dict"]:
            return (
                sequence_output,
                pooled_output,
@@ -778,6 +769,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
    ):
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -862,6 +854,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -875,7 +868,6 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
        outputs = self.albert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -892,7 +884,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output, training=inputs["training"])

-        if not return_dict:
+        if not inputs["return_dict"]:
            return (prediction_scores, sop_scores) + outputs[2:]

        return TFAlbertForPreTrainingOutput(
@@ -964,6 +956,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -977,8 +970,6 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
        outputs = self.albert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -995,8 +986,9 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
        prediction_scores = self.predictions(sequence_output, training=inputs["training"])
        loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
@@ -1055,6 +1047,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1068,7 +1061,6 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
        outputs = self.albert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1086,8 +1078,9 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
        logits = self.classifier(pooled_output)
        loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
@@ -1148,6 +1141,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1161,7 +1155,6 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
        outputs = self.albert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1179,8 +1172,9 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
        logits = self.classifier(sequence_output)
        loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
@@ -1246,6 +1240,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1260,7 +1255,6 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
        outputs = self.albert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1285,8 +1279,9 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
            labels["end_position"] = inputs["end_positions"]
            loss = self.compute_loss(labels, (start_logits, end_logits))

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
@@ -1355,6 +1350,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1368,7 +1364,6 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
        if inputs["input_ids"] is not None:
            num_choices = shape_list(inputs["input_ids"])[1]
@@ -1400,7 +1395,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
            flat_inputs_embeds,
            inputs["output_attentions"],
            inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
@@ -1412,7 +1407,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
        loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
...
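The hunks above all follow the same simplification: the per-model fallback lines (return_dict = inputs["return_dict"] if ... else self.albert.return_dict) disappear, and the call body reads the already-resolved flag straight from the processed inputs. A minimal sketch of the resulting tail-of-call pattern, not the transformers code itself:

    def format_output(inputs, loss, logits, extra):
        # inputs["return_dict"] is guaranteed to be a concrete bool after input_processing
        if not inputs["return_dict"]:
            output = (logits,) + extra
            return ((loss,) + output) if loss is not None else output
        return {"loss": loss, "logits": logits, "extra": extra}
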
@@ -270,7 +270,7 @@ class TFEncoderLayer(tf.keras.layers.Layer):
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
-        x = tf.nn.dropout(x, rate=self.activation_dropout if training else 0)
+        x = tf.nn.dropout(x, rate=self.self.activation_dropout if training else 0)
        x = self.fc2(x)
        x = tf.nn.dropout(x, rate=self.dropout if training else 0)
        x = residual + x
@@ -938,6 +938,7 @@ class TFBartModel(TFPretrainedBartModel):
    ):
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
@@ -951,19 +952,17 @@ class TFBartModel(TFPretrainedBartModel):
            training=training,
            kwargs_call=kwargs,
        )
-        use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
-        if inputs["decoder_input_ids"] is None:  # Classification
-            use_cache = False
-        output_attentions = (
-            inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
-        )
-        output_hidden_states = (
+        if inputs["decoder_input_ids"] is None:
+            inputs["use_cache"] = False
+
+        inputs["output_hidden_states"] = (
            inputs["output_hidden_states"]
            if inputs["output_hidden_states"] is not None
            else self.config.output_hidden_states
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict

-        if not use_cache:
+        if not inputs["use_cache"]:
            inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
                inputs["input_ids"],
                decoder_input_ids=inputs["decoder_input_ids"],
@@ -972,25 +971,24 @@ class TFBartModel(TFPretrainedBartModel):
            )
        else:
            decoder_padding_mask, causal_mask = None, None

        if inputs["encoder_outputs"] is None:
            inputs["encoder_outputs"] = self.encoder(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                output_attentions=inputs["output_attentions"],
+                output_hidden_states=inputs["output_hidden_states"],
+                return_dict=inputs["return_dict"],
                training=inputs["training"],
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
+        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
            inputs["encoder_outputs"] = TFBaseModelOutput(
                last_hidden_state=inputs["encoder_outputs"][0],
                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
            )
        # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not return_dict and not isinstance(inputs["encoder_outputs"], tuple):
+        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()

        decoder_outputs = self.decoder(
@@ -1000,14 +998,14 @@ class TFBartModel(TFPretrainedBartModel):
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            decoder_cached_states=inputs["past_key_values"],
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            use_cache=inputs["use_cache"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )

-        if not return_dict:
+        if not inputs["return_dict"]:
            return decoder_outputs + inputs["encoder_outputs"]

        return TFSeq2SeqModelOutput(
@@ -1090,6 +1088,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
@@ -1104,10 +1103,9 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
-        use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
        if inputs["labels"] is not None:
-            use_cache = False
+            inputs["use_cache"] = False
            if inputs["decoder_input_ids"] is None:
                inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
@@ -1118,19 +1116,18 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
            encoder_outputs=inputs["encoder_outputs"],
            decoder_attention_mask=inputs["decoder_attention_mask"],
            past_key_values=inputs["past_key_values"],
-            use_cache=use_cache,
+            use_cache=inputs["use_cache"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
        )
        lm_logits = self.model.shared(outputs[0], mode="linear")
        lm_logits = lm_logits + self.final_logits_bias
        masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return TFSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
...
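The eager/graph distinction matters for models like BART that are often traced with tf.function: inside a trace tf.executing_eagerly() is False, so booleans_processing ignores per-call overrides of output_attentions, output_hidden_states and use_cache (logging a warning) and forces return_dict to True. A small sketch of that mode check:

    import tensorflow as tf

    @tf.function
    def traced():
        # During tracing this is False, so the config values would be used.
        return tf.constant(tf.executing_eagerly())

    print(tf.executing_eagerly())  # True outside the traced function
    print(traced())                # tf.Tensor(False, shape=(), dtype=bool)
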
@@ -550,6 +550,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

+        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions
@@ -589,6 +590,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
    ):
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -601,13 +603,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
            training=training,
            kwargs_call=kwargs,
        )
-        output_attentions = (
-            inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
-        )
-        output_hidden_states = (
-            inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
-        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict

        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -661,16 +656,16 @@ class TFBertMainLayer(tf.keras.layers.Layer):
            embedding_output,
            extended_attention_mask,
            inputs["head_mask"],
-            output_attentions,
-            output_hidden_states,
-            return_dict,
+            inputs["output_attentions"],
+            inputs["output_hidden_states"],
+            inputs["return_dict"],
            training=inputs["training"],
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

-        if not return_dict:
+        if not inputs["return_dict"]:
            return (
                sequence_output,
                pooled_output,
@@ -848,6 +843,7 @@ class TFBertModel(TFBertPreTrainedModel):
    ):
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -929,6 +925,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -943,7 +940,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
        outputs = self.bert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -953,7 +949,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        sequence_output, pooled_output = outputs[:2]
@@ -966,7 +962,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
            d_labels["next_sentence_label"] = inputs["next_sentence_label"]
            total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

-        if not return_dict:
+        if not inputs["return_dict"]:
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        return TFBertForPreTrainingOutput(
@@ -1029,6 +1025,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1042,7 +1039,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
        outputs = self.bert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1052,14 +1048,14 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=inputs["training"])
        loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
@@ -1116,6 +1112,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1129,7 +1126,6 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
        outputs = self.bert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1139,7 +1135,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        sequence_output = outputs[0]
@@ -1152,7 +1148,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
            labels = inputs["labels"][:, 1:]
            loss = self.compute_loss(labels, logits)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
@@ -1212,6 +1208,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1225,7 +1222,6 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
        outputs = self.bert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1235,7 +1231,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        pooled_output = outputs[1]
@@ -1246,7 +1242,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
            else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
        )

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
@@ -1306,6 +1302,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1319,7 +1316,6 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
        outputs = self.bert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1329,7 +1325,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        pooled_output = outputs[1]
@@ -1337,7 +1333,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
        logits = self.classifier(pooled_output)
        loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
@@ -1406,6 +1402,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1419,7 +1416,6 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
        if inputs["input_ids"] is not None:
            num_choices = shape_list(inputs["input_ids"])[1]
@@ -1452,7 +1448,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
            flat_inputs_embeds,
            inputs["output_attentions"],
            inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        pooled_output = outputs[1]
@@ -1461,7 +1457,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
        reshaped_logits = tf.reshape(logits, (-1, num_choices))
        loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
@@ -1524,6 +1520,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1537,7 +1534,6 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
        outputs = self.bert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1547,7 +1543,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        sequence_output = outputs[0]
@@ -1555,7 +1551,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
        logits = self.classifier(sequence_output)
        loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
@@ -1623,6 +1619,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
        """
        inputs = input_processing(
            func=self.call,
+            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -1637,7 +1634,6 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
            training=training,
            kwargs_call=kwargs,
        )
-        return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
        outputs = self.bert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -1647,7 +1643,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=return_dict,
+            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        sequence_output = outputs[0]
@@ -1662,7 +1658,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
            labels["end_position"] = inputs["end_positions"]
            loss = self.compute_loss(labels, (start_logits, end_logits))

-        if not return_dict:
+        if not inputs["return_dict"]:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
...
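As the warning string in booleans_processing suggests, the supported way to get attentions or hidden states out of a model that will run in graph mode is to set the flags on the config up front. A sketch with BERT (the checkpoint name is only illustrative):

    from transformers import BertConfig, TFBertModel

    config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True)
    model = TFBertModel.from_pretrained("bert-base-uncased", config=config)
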
...@@ -204,6 +204,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -204,6 +204,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.use_cache = config.use_cache self.use_cache = config.use_cache
...@@ -267,6 +269,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -267,6 +269,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
past=past, past=past,
attention_mask=attention_mask, attention_mask=attention_mask,
...@@ -281,14 +284,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -281,14 +284,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
# If using past key value states, only the last tokens # If using past key value states, only the last tokens
# should be given as an input # should be given as an input
...@@ -375,8 +370,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -375,8 +370,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
output_shape = input_shape + [shape_list(hidden_states)[-1]] output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if inputs["use_cache"] else None presents = () if inputs["use_cache"] else None
- all_hidden_states = () if output_hidden_states else None
+ all_hidden_states = () if inputs["output_hidden_states"] else None
- all_attentions = () if output_attentions else None
+ all_attentions = () if inputs["output_attentions"] else None
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
@@ -400,15 +395,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
hidden_states = self.layernorm(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape)
- if output_hidden_states:
+ if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,)
- if output_attentions:
+ if inputs["output_attentions"]:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
- if not return_dict:
+ if not inputs["return_dict"]:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutputWithPast(
@@ -566,6 +561,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
):
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
@@ -671,6 +667,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
@@ -686,7 +683,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
training=training,
kwargs_call=kwargs,
)
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
@@ -698,7 +694,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
@@ -713,7 +709,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
...
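Note: the CTRL hunks above show the pattern this commit applies everywhere — `call` hands its `config` to `input_processing`, which resolves the output booleans once, and the method body then reads `inputs["output_attentions"]`, `inputs["output_hidden_states"]`, `inputs["use_cache"]` and `inputs["return_dict"]` instead of recomputing local defaults. The snippet below is a minimal, self-contained sketch of that flow; `ToyConfig`, `toy_input_processing` and `ToyLMHead` are illustrative stand-ins, not the actual `transformers` helpers.

```python
# Illustrative stand-ins only; the real input_processing / booleans_processing
# helpers live in transformers.modeling_tf_utils.
from dataclasses import dataclass


@dataclass
class ToyConfig:
    output_attentions: bool = False
    output_hidden_states: bool = False
    return_dict: bool = True


def toy_input_processing(config, **kwargs):
    """Resolve every boolean that was left as None against the config."""
    resolved = dict(kwargs)
    for name in ("output_attentions", "output_hidden_states", "return_dict"):
        if resolved.get(name) is None:
            resolved[name] = getattr(config, name)
    return resolved


class ToyLMHead:
    def __init__(self, config):
        self.config = config

    def call(self, input_ids, output_attentions=None, output_hidden_states=None, return_dict=None):
        inputs = toy_input_processing(
            config=self.config,
            input_ids=input_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The body only ever looks at the resolved values, e.g. inputs["return_dict"].
        return inputs["return_dict"]


print(ToyLMHead(ToyConfig()).call([1, 2, 3]))  # True, taken from the config default
```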
@@ -388,6 +388,8 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
+ self.config = config
self.num_hidden_layers = config.num_hidden_layers
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
@@ -420,6 +422,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
):
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -430,13 +433,6 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
training=training,
kwargs_call=kwargs,
)
- output_attentions = (
- inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
- )
- output_hidden_states = (
- inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
- )
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -469,9 +465,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
embedding_output,
inputs["attention_mask"],
inputs["head_mask"],
- output_attentions,
+ inputs["output_attentions"],
- output_hidden_states,
+ inputs["output_hidden_states"],
- return_dict,
+ inputs["return_dict"],
training=inputs["training"],
)
@@ -596,6 +592,7 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
):
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -686,6 +683,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -697,7 +695,6 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
training=training,
kwargs_call=kwargs,
)
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -705,10 +702,9 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = distilbert_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = self.act(prediction_logits) # (bs, seq_length, dim)
@@ -717,7 +713,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_logits)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (prediction_logits,) + distilbert_output[1:]
return ((loss,) + output) if loss is not None else output
@@ -781,6 +777,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -792,7 +789,6 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
training=training,
kwargs_call=kwargs,
)
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -800,10 +796,9 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
@@ -812,7 +807,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (logits,) + distilbert_output[1:]
return ((loss,) + output) if loss is not None else output
@@ -869,6 +864,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -880,7 +876,6 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
training=training,
kwargs_call=kwargs,
)
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
outputs = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -888,18 +883,15 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -974,6 +966,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -985,7 +978,6 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
training=training,
kwargs_call=kwargs,
)
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
@@ -1010,7 +1002,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
flat_inputs_embeds,
inputs["output_attentions"],
inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
@@ -1022,7 +1014,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (reshaped_logits,) + distilbert_output[1:]
return ((loss,) + output) if loss is not None else output
@@ -1085,6 +1077,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -1097,7 +1090,6 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
training=training,
kwargs_call=kwargs,
)
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -1105,10 +1097,9 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states, training=inputs["training"]) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
@@ -1122,7 +1113,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
- if not return_dict:
+ if not inputs["return_dict"]:
output = (start_logits, end_logits) + distilbert_output[1:]
return ((loss,) + output) if loss is not None else output
...
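Each DistilBERT head above ends with the same return convention: when `inputs["return_dict"]` is false, the logits are packed into a plain tuple together with whatever else the backbone returned, and the loss is prepended only when labels were supplied. Below is a small self-contained sketch of that convention; `build_head_output` is a hypothetical helper, and the dict in the `return_dict=True` branch merely stands in for the typed model output classes.

```python
def build_head_output(loss, logits, backbone_extras, return_dict):
    """Mimics the `if not inputs["return_dict"]` branch used by the heads above."""
    if not return_dict:
        output = (logits,) + tuple(backbone_extras)
        # Prepend the loss only when labels were provided.
        return ((loss,) + output) if loss is not None else output
    # Stand-in for the typed ModelOutput classes returned otherwise.
    return {"loss": loss, "logits": logits, "extras": tuple(backbone_extras)}


print(build_head_output(None, [0.1, 0.9], [], return_dict=False))  # ([0.1, 0.9],)
print(build_head_output(0.3, [0.1, 0.9], [], return_dict=False))   # (0.3, [0.1, 0.9])
print(build_head_output(0.3, [0.1, 0.9], [], return_dict=True))    # dict with loss/logits/extras
```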
@@ -176,6 +176,7 @@ class TFDPREncoder(TFPreTrainedModel):
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -186,7 +187,6 @@ class TFDPREncoder(TFPreTrainedModel):
training=training,
kwargs_call=kwargs,
)
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert_model.return_dict
outputs = self.bert_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -194,7 +194,7 @@ class TFDPREncoder(TFPreTrainedModel):
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
@@ -203,7 +203,7 @@ class TFDPREncoder(TFPreTrainedModel):
if self.projection_dim > 0:
pooled_output = self.encode_proj(pooled_output)
- if not return_dict:
+ if not inputs["return_dict"]:
return (sequence_output, pooled_output) + outputs[2:]
return TFBaseModelOutputWithPooling(
@@ -237,7 +237,7 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
def call(
self,
- input_ids: tf.Tensor,
+ input_ids: tf.Tensor = None,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
@@ -253,6 +253,7 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -263,17 +264,13 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
training=training,
kwargs_call=kwargs,
)
- return_dict = (
- inputs["return_dict"] if inputs["return_dict"] is not None else self.encoder.bert_model.return_dict
- )
outputs = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
- token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
@@ -290,7 +287,7 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
end_logits = tf.reshape(end_logits, [n_passages, sequence_length])
relevance_logits = tf.reshape(relevance_logits, [n_passages])
- if not return_dict:
+ if not inputs["return_dict"]:
return (start_logits, end_logits, relevance_logits) + outputs[2:]
return TFDPRReaderOutput(
@@ -501,6 +498,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -511,15 +509,6 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
training=training,
kwargs_call=kwargs,
)
- output_attentions = (
- inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
- )
- output_hidden_states = (
- inputs["output_hidden_states"]
- if inputs["output_hidden_states"] is not None
- else self.config.output_hidden_states
- )
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -544,13 +533,13 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
- output_attentions=output_attentions,
+ output_attentions=inputs["output_attentions"],
- output_hidden_states=output_hidden_states,
+ output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
- if not return_dict:
+ if not inputs["return_dict"]:
return outputs[1:]
return TFDPRContextEncoderOutput(
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
@@ -597,6 +586,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -607,15 +597,6 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
training=training,
kwargs_call=kwargs,
)
- output_attentions = (
- inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
- )
- output_hidden_states = (
- inputs["output_hidden_states"]
- if inputs["output_hidden_states"] is not None
- else self.config.output_hidden_states
- )
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -640,13 +621,13 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
- output_attentions=output_attentions,
+ output_attentions=inputs["output_attentions"],
- output_hidden_states=output_hidden_states,
+ output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
- if not return_dict:
+ if not inputs["return_dict"]:
return outputs[1:]
return TFDPRQuestionEncoderOutput(
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
@@ -702,6 +683,7 @@ class TFDPRReader(TFDPRPretrainedReader):
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -712,15 +694,6 @@ class TFDPRReader(TFDPRPretrainedReader):
training=training,
kwargs_call=kwargs,
)
- output_attentions = (
- inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
- )
- output_hidden_states = (
- inputs["output_hidden_states"]
- if inputs["output_hidden_states"] is not None
- else self.config.output_hidden_states
- )
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -734,16 +707,16 @@ class TFDPRReader(TFDPRPretrainedReader):
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.ones(input_shape, dtype=tf.dtypes.int32)
- if token_type_ids is None:
+ if inputs["token_type_ids"] is None:
- token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
+ inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
return self.span_predictor(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
- output_attentions=output_attentions,
+ output_attentions=inputs["output_attentions"],
- output_hidden_states=output_hidden_states,
+ output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
@@ -477,6 +477,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
+ self.config = config
self.embeddings = TFElectraEmbeddings(config, name="embeddings")
if config.embedding_size != config.hidden_size:
@@ -547,6 +548,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
):
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -559,15 +561,6 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
training=training,
kwargs_call=kwargs,
)
- output_attentions = (
- inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
- )
- output_hidden_states = (
- inputs["output_hidden_states"]
- if inputs["output_hidden_states"] is not None
- else self.config.output_hidden_states
- )
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -603,9 +596,9 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
hidden_states,
extended_attention_mask,
inputs["head_mask"],
- output_attentions,
+ inputs["output_attentions"],
- output_hidden_states,
+ inputs["output_hidden_states"],
- return_dict,
+ inputs["return_dict"],
training=inputs["training"],
)
@@ -759,6 +752,7 @@ class TFElectraModel(TFElectraPreTrainedModel):
):
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -835,6 +829,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -847,9 +842,6 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
training=training,
kwargs_call=kwargs,
)
- return_dict = (
- inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
- )
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -859,13 +851,13 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output)
- if not return_dict:
+ if not inputs["return_dict"]:
return (logits,) + discriminator_hidden_states[1:]
return TFElectraForPreTrainingOutput(
@@ -951,6 +943,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -964,9 +957,6 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
training=training,
kwargs_call=kwargs,
)
- return_dict = (
- inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
- )
generator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -976,7 +966,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
generator_sequence_output = generator_hidden_states[0]
@@ -984,7 +974,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (prediction_scores,) + generator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
@@ -1066,6 +1056,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -1079,9 +1070,6 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
training=training,
kwargs_call=kwargs,
)
- return_dict = (
- inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
- )
outputs = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -1091,13 +1079,13 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
logits = self.classifier(outputs[0])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -1169,6 +1157,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -1182,9 +1171,6 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
training=training,
kwargs_call=kwargs,
)
- return_dict = (
- inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
- )
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
@@ -1217,7 +1203,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
flat_inputs_embeds,
inputs["output_attentions"],
inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
logits = self.sequence_summary(outputs[0])
@@ -1225,7 +1211,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -1285,6 +1271,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -1298,9 +1285,6 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
training=training,
kwargs_call=kwargs,
)
- return_dict = (
- inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
- )
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -1310,7 +1294,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
discriminator_sequence_output = discriminator_hidden_states[0]
@@ -1318,7 +1302,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
logits = self.classifier(discriminator_sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
- if not return_dict:
+ if not inputs["return_dict"]:
output = (logits,) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
@@ -1383,6 +1367,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
"""
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -1397,9 +1382,6 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
training=training,
kwargs_call=kwargs,
)
- return_dict = (
- inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
- )
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -1409,7 +1391,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
discriminator_sequence_output = discriminator_hidden_states[0]
@@ -1424,7 +1406,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
- if not return_dict:
+ if not inputs["return_dict"]:
output = (
start_logits,
end_logits,
...
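The `__init__` changes above (`self.config = config` in `TFElectraMainLayer`, mirrored in the DistilBERT layer earlier and in the Flaubert and Funnel layers below) exist so that `call` can hand the whole config to `input_processing` via `config=self.config`. Below is a minimal sketch of that layer shape, assuming TensorFlow 2.x is installed; `ToyMainLayer` and `ToyConfig` are hypothetical names, not library classes.

```python
import tensorflow as tf


class ToyConfig:
    output_attentions = False
    output_hidden_states = False


class ToyMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Keep the whole config around (the change made in this commit) rather
        # than only copying individual flags such as config.output_attentions.
        self.config = config

    def call(self, hidden_states, output_attentions=None):
        # With the config at hand, an unset boolean can be resolved in one place.
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        return hidden_states, output_attentions


layer = ToyMainLayer(ToyConfig())
hidden, attn_flag = layer(tf.constant([1.0, 2.0]))
print(attn_flag)  # False, resolved from the config default
```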
@@ -253,6 +253,7 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
):
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
@@ -407,6 +408,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
+ self.config = config
self.n_heads = config.n_heads
self.n_langs = config.n_langs
self.dim = config.emb_dim
@@ -488,6 +490,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# removed: src_enc=None, src_len=None
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
@@ -503,13 +506,6 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
training=training,
kwargs_call=kwargs,
)
- output_attentions = (
- inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
- )
- output_hidden_states = (
- inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
- )
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -611,7 +607,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
if inputs["training"] and tf.less(dropout_probability, self.layerdrop):
continue
- if output_hidden_states:
+ if inputs["output_hidden_states"]:
hidden_states = hidden_states + (tensor,)
# self attention
@@ -622,12 +618,12 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
None,
inputs["cache"],
inputs["head_mask"][i],
- output_attentions,
+ inputs["output_attentions"],
training=inputs["training"],
)
attn = attn_outputs[0]
- if output_attentions:
+ if inputs["output_attentions"]:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=inputs["training"])
@@ -641,7 +637,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
None,
inputs["cache"],
inputs["head_mask"][i],
- output_attentions,
+ inputs["output_attentions"],
training=inputs["training"],
)
attn = attn_outputs[0]
@@ -670,7 +666,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
tensor = tensor * mask[..., tf.newaxis]
# Add last hidden state
- if output_hidden_states:
+ if inputs["output_hidden_states"]:
hidden_states = hidden_states + (tensor,)
# update cache length
@@ -681,10 +677,10 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# tensor = tensor.transpose(0, 1)
# Set to None here if the output booleans are at False
- hidden_states = hidden_states if output_hidden_states else None
+ hidden_states = hidden_states if inputs["output_hidden_states"] else None
- attentions = attentions if output_attentions else None
+ attentions = attentions if inputs["output_attentions"] else None
- if not return_dict:
+ if not inputs["return_dict"]:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
@@ -810,6 +806,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
):
inputs = input_processing(
func=self.call,
+ config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
@@ -825,7 +822,6 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
training=training,
kwargs_call=kwargs,
)
- return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
@@ -838,13 +834,13 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
- return_dict=return_dict,
+ return_dict=inputs["return_dict"],
training=inputs["training"],
)
output = transformer_outputs[0]
outputs = self.pred_layer(output)
- if not return_dict:
+ if not inputs["return_dict"]:
return (outputs,) + transformer_outputs[1:]
return TFFlaubertWithLMHeadModelOutput(
...
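In the Flaubert main layer above, the resolved flags also gate how the per-layer tuples are built during the forward loop: `hidden_states` and `attentions` only grow when the corresponding boolean is set, are forced back to `None` otherwise, and the tuple return drops the `None` entries. A tiny pure-Python sketch of that bookkeeping follows; `run_layers` and the string placeholders are purely illustrative.

```python
def run_layers(per_layer, output_hidden_states, output_attentions, return_dict):
    """Toy version of the accumulation pattern in TFFlaubertMainLayer.call."""
    hidden_states = () if output_hidden_states else None
    attentions = () if output_attentions else None
    tensor = None
    for tensor, attn in per_layer:
        if output_hidden_states:
            hidden_states = hidden_states + (tensor,)
        if output_attentions:
            attentions = attentions + (attn,)
    # Set to None here if the output booleans are at False.
    hidden_states = hidden_states if output_hidden_states else None
    attentions = attentions if output_attentions else None
    if not return_dict:
        return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
    return {"last_hidden_state": tensor, "hidden_states": hidden_states, "attentions": attentions}


layers = [("h1", "a1"), ("h2", "a2")]
print(run_layers(layers, output_hidden_states=True, output_attentions=False, return_dict=False))
# ('h2', ('h1', 'h2'))
```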
...@@ -764,6 +764,8 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer): ...@@ -764,6 +764,8 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.return_dict = config.use_return_dict self.return_dict = config.use_return_dict
...@@ -795,6 +797,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer): ...@@ -795,6 +797,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -805,13 +808,6 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer): ...@@ -805,13 +808,6 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -835,9 +831,9 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer): ...@@ -835,9 +831,9 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
inputs["inputs_embeds"], inputs["inputs_embeds"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"], token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=output_hidden_states, output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -852,6 +848,8 @@ class TFFunnelMainLayer(tf.keras.layers.Layer): ...@@ -852,6 +848,8 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.block_sizes = config.block_sizes self.block_sizes = config.block_sizes
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
...@@ -885,6 +883,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer): ...@@ -885,6 +883,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -895,13 +894,6 @@ class TFFunnelMainLayer(tf.keras.layers.Layer): ...@@ -895,13 +894,6 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -925,9 +917,9 @@ class TFFunnelMainLayer(tf.keras.layers.Layer): ...@@ -925,9 +917,9 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
inputs["inputs_embeds"], inputs["inputs_embeds"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"], token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=True, output_hidden_states=True,
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -936,18 +928,19 @@ class TFFunnelMainLayer(tf.keras.layers.Layer): ...@@ -936,18 +928,19 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
first_block_hidden=encoder_outputs[1][self.block_sizes[0]], first_block_hidden=encoder_outputs[1][self.block_sizes[0]],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"], token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=output_hidden_states, output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"],
) )
if not return_dict: if not inputs["return_dict"]:
idx = 0 idx = 0
outputs = (decoder_outputs[0],) outputs = (decoder_outputs[0],)
if output_hidden_states: if inputs["output_hidden_states"]:
idx += 1 idx += 1
outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],) outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],)
if output_attentions: if inputs["output_attentions"]:
idx += 1 idx += 1
outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],) outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],)
return outputs return outputs
...@@ -955,9 +948,11 @@ class TFFunnelMainLayer(tf.keras.layers.Layer): ...@@ -955,9 +948,11 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
return TFBaseModelOutput( return TFBaseModelOutput(
last_hidden_state=decoder_outputs[0], last_hidden_state=decoder_outputs[0],
hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states) hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states)
if output_hidden_states if inputs["output_hidden_states"]
else None, else None,
attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None, attentions=(encoder_outputs.attentions + decoder_outputs.attentions)
if inputs["output_attentions"]
else None,
) )
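When inputs["return_dict"] is False, the branch above packs the decoder output plus the optional hidden-state and attention tuples positionally through the idx counter, so a caller indexes by position; with a ModelOutput the same pieces are reached by name. A rough eager-mode sketch (the "funnel-transformer/small" checkpoint name comes from the Funnel docstrings and is assumed to ship TF weights; the token IDs are placeholders):

import tensorflow as tf
from transformers import TFFunnelModel

model = TFFunnelModel.from_pretrained("funnel-transformer/small")
input_ids = tf.constant([[101, 2023, 2003, 1037, 7099, 102]])  # placeholder IDs

# ModelOutput form: pieces that were not requested would simply be None.
out = model(input_ids, output_attentions=True, output_hidden_states=True, return_dict=True)
print(len(out.hidden_states), len(out.attentions))

# Tuple form: the same pieces come back in the order the idx bookkeeping builds them.
last_hidden_state, hidden_states, attentions = model(
    input_ids, output_attentions=True, output_hidden_states=True, return_dict=False
)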
...@@ -1162,6 +1157,7 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel): ...@@ -1162,6 +1157,7 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1172,7 +1168,6 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel): ...@@ -1172,7 +1168,6 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
return self.funnel( return self.funnel(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
...@@ -1181,7 +1176,7 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel): ...@@ -1181,7 +1176,7 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -1216,6 +1211,7 @@ class TFFunnelModel(TFFunnelPreTrainedModel): ...@@ -1216,6 +1211,7 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1282,6 +1278,7 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel): ...@@ -1282,6 +1278,7 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1292,7 +1289,6 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel): ...@@ -1292,7 +1289,6 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
discriminator_hidden_states = self.funnel( discriminator_hidden_states = self.funnel(
inputs["input_ids"], inputs["input_ids"],
inputs["attention_mask"], inputs["attention_mask"],
...@@ -1300,13 +1296,13 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel): ...@@ -1300,13 +1296,13 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
inputs["inputs_embeds"], inputs["inputs_embeds"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output) logits = self.discriminator_predictions(discriminator_sequence_output)
if not return_dict: if not inputs["return_dict"]:
return (logits,) + discriminator_hidden_states[1:] return (logits,) + discriminator_hidden_states[1:]
return TFFunnelForPreTrainingOutput( return TFFunnelForPreTrainingOutput(
...@@ -1352,6 +1348,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss) ...@@ -1352,6 +1348,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1363,7 +1360,6 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss) ...@@ -1363,7 +1360,6 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel( outputs = self.funnel(
inputs["input_ids"], inputs["input_ids"],
inputs["attention_mask"], inputs["attention_mask"],
...@@ -1379,7 +1375,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss) ...@@ -1379,7 +1375,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[1:] output = (prediction_scores,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1434,6 +1430,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass ...@@ -1434,6 +1430,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1445,7 +1442,6 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass ...@@ -1445,7 +1442,6 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel( outputs = self.funnel(
inputs["input_ids"], inputs["input_ids"],
inputs["attention_mask"], inputs["attention_mask"],
...@@ -1453,7 +1449,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass ...@@ -1453,7 +1449,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
inputs["inputs_embeds"], inputs["inputs_embeds"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -1463,7 +1459,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass ...@@ -1463,7 +1459,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1527,6 +1523,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1527,6 +1523,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1538,7 +1535,6 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1538,7 +1535,6 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(inputs["input_ids"])[1]
...@@ -1567,7 +1563,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1567,7 +1563,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds=flat_inputs_embeds, inputs_embeds=flat_inputs_embeds,
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -1578,7 +1574,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1578,7 +1574,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[1:] output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1635,6 +1631,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat ...@@ -1635,6 +1631,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1646,7 +1643,6 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat ...@@ -1646,7 +1643,6 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel( outputs = self.funnel(
inputs["input_ids"], inputs["input_ids"],
inputs["attention_mask"], inputs["attention_mask"],
...@@ -1654,10 +1650,9 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat ...@@ -1654,10 +1650,9 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
inputs["inputs_embeds"], inputs["inputs_embeds"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"]) sequence_output = self.dropout(sequence_output, training=inputs["training"])
...@@ -1665,7 +1660,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat ...@@ -1665,7 +1660,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1727,6 +1722,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL ...@@ -1727,6 +1722,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1747,10 +1743,9 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL ...@@ -1747,10 +1743,9 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
inputs["inputs_embeds"], inputs["inputs_embeds"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
...@@ -1763,7 +1758,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL ...@@ -1763,7 +1758,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
labels = {"start_position": inputs["start_positions"], "end_position": inputs["end_positions"]} labels = {"start_position": inputs["start_positions"], "end_position": inputs["end_positions"]}
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[1:] output = (start_logits, end_logits) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
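Across the Funnel heads above, the per-method return_dict fallback lines disappear because input_processing now receives config=self.config and hands each method already-resolved booleans in the inputs dict. A self-contained sketch of the fallback rule those deleted ternaries expressed (SimpleNamespace stands in for the real config; resolve is illustrative, not the library helper):

from types import SimpleNamespace

config = SimpleNamespace(output_attentions=False, output_hidden_states=False, use_return_dict=True)

def resolve(value, config_default):
    # The ternary deleted from each method above, kept in a single place:
    # an explicit per-call value wins, None falls back to the config field.
    return value if value is not None else config_default

print(resolve(True, config.output_attentions))      # True: caller override
print(resolve(None, config.output_hidden_states))   # False: config default
print(resolve(None, config.use_return_dict))        # True: config default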
...@@ -209,6 +209,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -209,6 +209,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
self.config = config
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.use_cache = config.use_cache self.use_cache = config.use_cache
...@@ -262,6 +264,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -262,6 +264,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
past=past, past=past,
attention_mask=attention_mask, attention_mask=attention_mask,
...@@ -276,14 +279,6 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -276,14 +279,6 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -358,11 +353,11 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -358,11 +353,11 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
output_shape = input_shape + [shape_list(hidden_states)[-1]] output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if use_cache else None presents = () if inputs["use_cache"] else None
all_attentions = () if output_attentions else None all_attentions = () if inputs["output_attentions"] else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if inputs["output_hidden_states"] else None
for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])): for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states: if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block( outputs = block(
...@@ -370,31 +365,31 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -370,31 +365,31 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
layer_past, layer_past,
inputs["attention_mask"], inputs["attention_mask"],
inputs["head_mask"][i], inputs["head_mask"][i],
use_cache, inputs["use_cache"],
output_attentions, inputs["output_attentions"],
training=inputs["training"], training=inputs["training"],
) )
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if use_cache: if inputs["use_cache"]:
presents = presents + (present,) presents = presents + (present,)
if output_attentions: if inputs["output_attentions"]:
all_attentions = all_attentions + (outputs[2],) all_attentions = all_attentions + (outputs[2],)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape) hidden_states = tf.reshape(hidden_states, output_shape)
# Add last hidden state # Add last hidden state
if output_hidden_states: if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if output_attentions: if inputs["output_attentions"]:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
if not return_dict: if not inputs["return_dict"]:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutputWithPast( return TFBaseModelOutputWithPast(
...@@ -583,6 +578,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel): ...@@ -583,6 +578,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
past=past, past=past,
attention_mask=attention_mask, attention_mask=attention_mask,
...@@ -668,6 +664,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -668,6 +664,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
past=past, past=past,
attention_mask=attention_mask, attention_mask=attention_mask,
...@@ -683,7 +680,6 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -683,7 +680,6 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
past=inputs["past"], past=inputs["past"],
...@@ -695,7 +691,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -695,7 +691,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
use_cache=inputs["use_cache"], use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
...@@ -708,7 +704,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -708,7 +704,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -794,6 +790,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -794,6 +790,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
past=past, past=past,
attention_mask=attention_mask, attention_mask=attention_mask,
...@@ -809,7 +806,6 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -809,7 +806,6 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
input_shapes = shape_list(inputs["input_ids"]) input_shapes = shape_list(inputs["input_ids"])
...@@ -838,7 +834,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -838,7 +834,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
inputs["use_cache"], inputs["use_cache"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
...@@ -847,7 +843,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -847,7 +843,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"]) mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
mc_logits = tf.squeeze(mc_logits, axis=-1) mc_logits = tf.squeeze(mc_logits, axis=-1)
if not return_dict: if not inputs["return_dict"]:
return (lm_logits, mc_logits) + transformer_outputs[1:] return (lm_logits, mc_logits) + transformer_outputs[1:]
return TFGPT2DoubleHeadsModelOutput( return TFGPT2DoubleHeadsModelOutput(
......
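The GPT-2 loop above now reads use_cache, output_attentions and output_hidden_states straight from the processed inputs, and the cached presents still round-trip through the past argument. A minimal sketch of that round trip with a tiny, randomly initialized model (the config sizes are arbitrary, chosen only to keep the example fast):

import tensorflow as tf
from transformers import GPT2Config, TFGPT2LMHeadModel

config = GPT2Config(vocab_size=100, n_positions=64, n_ctx=64, n_embd=32, n_layer=2, n_head=2)
model = TFGPT2LMHeadModel(config)

input_ids = tf.constant([[1, 2, 3, 4]])
out = model(input_ids, use_cache=True, return_dict=True)

# The presents collected in the loop come back as past_key_values; feeding them
# in again through `past` means only the new token has to be processed.
next_id = tf.argmax(out.logits[:, -1, :], axis=-1, output_type=tf.int32)
out = model(tf.expand_dims(next_id, axis=-1), past=out.past_key_values, use_cache=True, return_dict=True)
print(out.logits.shape)  # (1, 1, 100)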
...@@ -1579,6 +1579,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1579,6 +1579,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
) )
self.config = config
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
self.initializer_range = config.initializer_range self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
...@@ -1620,6 +1621,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1620,6 +1621,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
...@@ -1632,13 +1634,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1632,13 +1634,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -1709,9 +1704,9 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1709,9 +1704,9 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
is_index_masked=is_index_masked, is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn, is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn, is_global_attn=is_global_attn,
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=output_hidden_states, output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
...@@ -1722,7 +1717,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ...@@ -1722,7 +1717,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
# unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1) # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
sequence_output = sequence_output[:, :-padding_len] sequence_output = sequence_output[:, :-padding_len]
if not return_dict: if not inputs["return_dict"]:
return ( return (
sequence_output, sequence_output,
pooled_output, pooled_output,
...@@ -1968,6 +1963,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel): ...@@ -1968,6 +1963,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
...@@ -2043,6 +2039,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -2043,6 +2039,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
...@@ -2056,7 +2053,6 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -2056,7 +2053,6 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
outputs = self.longformer( outputs = self.longformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -2066,14 +2062,14 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ...@@ -2066,14 +2062,14 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=training) prediction_scores = self.lm_head(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -2144,6 +2140,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn ...@@ -2144,6 +2140,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
...@@ -2158,7 +2155,6 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn ...@@ -2158,7 +2155,6 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
# set global attention on question tokens # set global attention on question tokens
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None: if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
...@@ -2189,7 +2185,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn ...@@ -2189,7 +2185,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
...@@ -2204,7 +2200,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn ...@@ -2204,7 +2200,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
labels["end_position"] = inputs["end_positions"] labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:] output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -2287,6 +2283,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque ...@@ -2287,6 +2283,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
...@@ -2300,7 +2297,6 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque ...@@ -2300,7 +2297,6 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None: if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
logger.info("Initializing global attention on CLS token...") logger.info("Initializing global attention on CLS token...")
...@@ -2321,7 +2317,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque ...@@ -2321,7 +2317,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
...@@ -2329,7 +2325,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque ...@@ -2329,7 +2325,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -2398,6 +2394,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic ...@@ -2398,6 +2394,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
...@@ -2411,7 +2408,6 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic ...@@ -2411,7 +2408,6 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(inputs["input_ids"])[1]
...@@ -2450,7 +2446,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic ...@@ -2450,7 +2446,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
inputs_embeds=flat_inputs_embeds, inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
...@@ -2461,7 +2457,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic ...@@ -2461,7 +2457,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -2524,6 +2520,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla ...@@ -2524,6 +2520,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
...@@ -2537,7 +2534,6 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla ...@@ -2537,7 +2534,6 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
outputs = self.longformer( outputs = self.longformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -2547,7 +2543,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla ...@@ -2547,7 +2543,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
...@@ -2555,7 +2551,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla ...@@ -2555,7 +2551,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
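In the Longformer heads above, a default global_attention_mask is filled in for the caller when none is given (the CLS token for sequence classification, the question tokens for question answering). A short sketch of the explicit CLS-only mask, with made-up token IDs; the commented model call is hypothetical:

import tensorflow as tf

input_ids = tf.constant([[0, 9, 10, 11, 2, 12, 13, 14, 2, 1, 1, 1]])  # made-up IDs, <pad>=1

# 1 marks tokens that attend globally; this reproduces the CLS-only default the
# sequence-classification head falls back to when no mask is passed.
global_attention_mask = tf.concat(
    [tf.ones_like(input_ids[:, :1]), tf.zeros_like(input_ids[:, 1:])], axis=-1
)
print(global_attention_mask.numpy())
# outputs = model(input_ids, global_attention_mask=global_attention_mask,
#                 output_attentions=True)  # hypothetical call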
...@@ -128,7 +128,7 @@ class TFLxmertForPreTrainingOutput(ModelOutput): ...@@ -128,7 +128,7 @@ class TFLxmertForPreTrainingOutput(ModelOutput):
""" """
loss: [tf.Tensor] = None loss: Optional[tf.Tensor] = None
prediction_logits: Optional[tf.Tensor] = None prediction_logits: Optional[tf.Tensor] = None
cross_relationship_score: Optional[tf.Tensor] = None cross_relationship_score: Optional[tf.Tensor] = None
question_answering_score: Optional[tf.Tensor] = None question_answering_score: Optional[tf.Tensor] = None
...@@ -687,6 +687,8 @@ class TFLxmertMainLayer(tf.keras.layers.Layer): ...@@ -687,6 +687,8 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.num_l_layers = config.l_layers self.num_l_layers = config.l_layers
self.num_x_layers = config.x_layers self.num_x_layers = config.x_layers
self.num_r_layers = config.r_layers self.num_r_layers = config.r_layers
...@@ -729,6 +731,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer): ...@@ -729,6 +731,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
visual_feats=visual_feats, visual_feats=visual_feats,
visual_pos=visual_pos, visual_pos=visual_pos,
...@@ -742,13 +745,6 @@ class TFLxmertMainLayer(tf.keras.layers.Layer): ...@@ -742,13 +745,6 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -804,7 +800,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer): ...@@ -804,7 +800,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
inputs["visual_feats"], inputs["visual_feats"],
inputs["visual_pos"], inputs["visual_pos"],
extended_visual_attention_mask, extended_visual_attention_mask,
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
training=inputs["training"], training=inputs["training"],
) )
visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2] visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
...@@ -812,7 +808,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer): ...@@ -812,7 +808,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
language_hidden_states = lang_encoder_outputs[0] language_hidden_states = lang_encoder_outputs[0]
all_attentions = () all_attentions = ()
if output_attentions: if inputs["output_attentions"]:
language_attentions = lang_encoder_outputs[1] language_attentions = lang_encoder_outputs[1]
vision_attentions = visual_encoder_outputs[1] vision_attentions = visual_encoder_outputs[1]
cross_encoder_attentions = encoder_outputs[2] cross_encoder_attentions = encoder_outputs[2]
...@@ -822,24 +818,24 @@ class TFLxmertMainLayer(tf.keras.layers.Layer): ...@@ -822,24 +818,24 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
cross_encoder_attentions, cross_encoder_attentions,
) )
hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else () hidden_states = (language_hidden_states, vision_hidden_states) if inputs["output_hidden_states"] else ()
visual_output = vision_hidden_states[-1] visual_output = vision_hidden_states[-1]
lang_output = language_hidden_states[-1] lang_output = language_hidden_states[-1]
pooled_output = self.pooler(lang_output) pooled_output = self.pooler(lang_output)
if not return_dict: if not inputs["return_dict"]:
return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions
return TFLxmertModelOutput( return TFLxmertModelOutput(
pooled_output=pooled_output, pooled_output=pooled_output,
language_output=lang_output, language_output=lang_output,
vision_output=visual_output, vision_output=visual_output,
language_hidden_states=language_hidden_states if output_hidden_states else None, language_hidden_states=language_hidden_states if inputs["output_hidden_states"] else None,
vision_hidden_states=vision_hidden_states if output_hidden_states else None, vision_hidden_states=vision_hidden_states if inputs["output_hidden_states"] else None,
language_attentions=language_attentions if output_attentions else None, language_attentions=language_attentions if inputs["output_attentions"] else None,
vision_attentions=vision_attentions if output_attentions else None, vision_attentions=vision_attentions if inputs["output_attentions"] else None,
cross_encoder_attentions=cross_encoder_attentions if output_attentions else None, cross_encoder_attentions=cross_encoder_attentions if inputs["output_attentions"] else None,
) )
...@@ -989,6 +985,7 @@ class TFLxmertModel(TFLxmertPreTrainedModel): ...@@ -989,6 +985,7 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
visual_feats=visual_feats, visual_feats=visual_feats,
visual_pos=visual_pos, visual_pos=visual_pos,
...@@ -1304,6 +1301,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel): ...@@ -1304,6 +1301,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
visual_feats=visual_feats, visual_feats=visual_feats,
visual_pos=visual_pos, visual_pos=visual_pos,
...@@ -1321,7 +1319,6 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel): ...@@ -1321,7 +1319,6 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.lxmert.return_dict
lxmert_output = self.lxmert( lxmert_output = self.lxmert(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
visual_feats=inputs["visual_feats"], visual_feats=inputs["visual_feats"],
...@@ -1407,7 +1404,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel): ...@@ -1407,7 +1404,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
losses += (answer_loss,) losses += (answer_loss,)
# return total_loss, tf.stack(losses)[tf.new_axis, ...], answer_score.detach() # return total_loss, tf.stack(losses)[tf.new_axis, ...], answer_score.detach()
if not return_dict: if not inputs["return_dict"]:
output = ( output = (
lang_prediction_scores, lang_prediction_scores,
cross_relationship_score, cross_relationship_score,
......
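Besides the boolean plumbing, the LXMERT hunk above also fixes the loss annotation on TFLxmertForPreTrainingOutput: [tf.Tensor] is a one-element list, not a type, while Optional[tf.Tensor] says "a tensor or None", which is how the loss is actually populated. A stripped-down, self-contained illustration (ToyOutput is a stand-in, not the real output class):

from dataclasses import dataclass
from typing import Optional
import tensorflow as tf

@dataclass
class ToyOutput:
    # Optional[tf.Tensor] is the annotation the field needs, since the loss
    # stays None until labels are provided.
    loss: Optional[tf.Tensor] = None
    prediction_logits: Optional[tf.Tensor] = None

print(ToyOutput())  # both fields default to None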
...@@ -688,6 +688,8 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): ...@@ -688,6 +688,8 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
...@@ -726,6 +728,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): ...@@ -726,6 +728,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -738,13 +741,6 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): ...@@ -738,13 +741,6 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -798,16 +794,16 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): ...@@ -798,16 +794,16 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
inputs["head_mask"], inputs["head_mask"],
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict, inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) pooled_output = self.pooler(sequence_output)
if not return_dict: if not inputs["return_dict"]:
return ( return (
sequence_output, sequence_output,
pooled_output, pooled_output,
...@@ -984,6 +980,7 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel): ...@@ -984,6 +980,7 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1062,6 +1059,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel): ...@@ -1062,6 +1059,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1074,7 +1072,6 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel): ...@@ -1074,7 +1072,6 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert( outputs = self.mobilebert(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -1084,7 +1081,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel): ...@@ -1084,7 +1081,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -1092,7 +1089,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel): ...@@ -1092,7 +1089,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output) seq_relationship_score = self.seq_relationship(pooled_output)
if not return_dict: if not inputs["return_dict"]:
return (prediction_scores, seq_relationship_score) + outputs[2:] return (prediction_scores, seq_relationship_score) + outputs[2:]
return TFMobileBertForPreTrainingOutput( return TFMobileBertForPreTrainingOutput(
...@@ -1147,6 +1144,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel ...@@ -1147,6 +1144,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1160,7 +1158,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert( outputs = self.mobilebert(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -1170,7 +1167,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
@@ -1179,7 +1176,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -1248,6 +1245,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1261,7 +1259,6 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert( outputs = self.mobilebert(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -1271,10 +1268,9 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
seq_relationship_scores = self.cls(pooled_output) seq_relationship_scores = self.cls(pooled_output)
@@ -1284,7 +1280,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores) else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
) )
if not return_dict: if not inputs["return_dict"]:
output = (seq_relationship_scores,) + outputs[2:] output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
@@ -1344,6 +1340,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1357,7 +1354,6 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert( outputs = self.mobilebert(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -1367,18 +1363,17 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -1445,6 +1440,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1459,7 +1455,6 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert( outputs = self.mobilebert(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -1469,10 +1464,9 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
@@ -1486,7 +1480,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
labels["end_position"] = inputs["end_positions"] labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:] output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -1558,6 +1552,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1571,7 +1566,6 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(inputs["input_ids"])[1]
@@ -1604,17 +1598,17 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
flat_inputs_embeds, flat_inputs_embeds,
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -1676,6 +1670,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1689,7 +1684,6 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert( outputs = self.mobilebert(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -1705,12 +1699,12 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
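Illustrative note (not part of the diff itself): each MobileBERT head above now hands config=self.config to input_processing and reads the resolved flags back from the processed inputs dict (inputs["return_dict"], inputs["output_attentions"], ...) instead of recomputing a local return_dict from the sub-model. A minimal eager-mode sketch of the caller-visible effect; the checkpoint name and the choice of head are assumptions made purely for illustration.

# Hedged sketch, not taken from this commit.
from transformers import MobileBertTokenizer, TFMobileBertForSequenceClassification

tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = TFMobileBertForSequenceClassification.from_pretrained("google/mobilebert-uncased")
batch = tokenizer("A short example sentence.", return_tensors="tf")

# Booleans passed at call time travel through the processed `inputs` dict.
dict_out = model(batch, output_hidden_states=True, return_dict=True)
print(len(dict_out.hidden_states))   # one tensor per layer plus the embedding output

# With return_dict=False the head takes the `if not inputs["return_dict"]` branch
# and returns a plain tuple whose first element is the logits.
tuple_out = model(batch, return_dict=False)
print(tuple_out[0].shape)            # (batch_size, num_labels)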
@@ -192,6 +192,8 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.return_dict = config.use_return_dict self.return_dict = config.use_return_dict
@@ -240,6 +242,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -252,13 +255,6 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -320,34 +316,34 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
output_shape = input_shape + [shape_list(hidden_states)[-1]] output_shape = input_shape + [shape_list(hidden_states)[-1]]
all_attentions = () if output_attentions else None all_attentions = () if inputs["output_attentions"] else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if inputs["output_hidden_states"] else None
for i, block in enumerate(self.h): for i, block in enumerate(self.h):
if output_hidden_states: if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block( outputs = block(
hidden_states, hidden_states,
inputs["attention_mask"], inputs["attention_mask"],
inputs["head_mask"][i], inputs["head_mask"][i],
output_attentions, inputs["output_attentions"],
training=inputs["training"], training=inputs["training"],
) )
hidden_states = outputs[0] hidden_states = outputs[0]
if output_attentions: if inputs["output_attentions"]:
all_attentions = all_attentions + (outputs[1],) all_attentions = all_attentions + (outputs[1],)
hidden_states = tf.reshape(hidden_states, output_shape) hidden_states = tf.reshape(hidden_states, output_shape)
# Add last hidden state # Add last hidden state
if output_hidden_states: if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if output_attentions: if inputs["output_attentions"]:
# let the number of heads free (-1) so we can extract attention even after head pruning # let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
if not return_dict: if not inputs["return_dict"]:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput( return TFBaseModelOutput(
@@ -519,6 +515,7 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -590,6 +587,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -603,7 +601,6 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -627,7 +624,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -707,6 +704,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -720,7 +718,6 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
input_shapes = shape_list(inputs["input_ids"]) input_shapes = shape_list(inputs["input_ids"])
@@ -747,7 +744,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
inputs["inputs_embeds"], inputs["inputs_embeds"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
@@ -756,7 +753,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"]) mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
mc_logits = tf.squeeze(mc_logits, axis=-1) mc_logits = tf.squeeze(mc_logits, axis=-1)
if not return_dict: if not inputs["return_dict"]:
return (lm_logits, mc_logits) + transformer_outputs[1:] return (lm_logits, mc_logits) + transformer_outputs[1:]
return TFOpenAIGPTDoubleHeadsModelOutput( return TFOpenAIGPTDoubleHeadsModelOutput(
......
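A similar caller-side sketch for the OpenAI GPT layer above, whose block loop now reads inputs["output_attentions"] and inputs["output_hidden_states"] directly. Again an illustration only; the checkpoint name is an assumption.

# Hedged sketch, not taken from this commit.
from transformers import OpenAIGPTTokenizer, TFOpenAIGPTModel

tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
model = TFOpenAIGPTModel.from_pretrained("openai-gpt")
enc = tokenizer("booleans are resolved per call", return_tensors="tf")

out = model(enc, output_attentions=True, output_hidden_states=True, return_dict=True)
# One attention tensor per block, reshaped to (batch, heads, seq_len, seq_len),
# and one hidden state per block plus the embedding output.
print(len(out.attentions), out.attentions[0].shape)
print(len(out.hidden_states))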
@@ -467,6 +467,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
self.initializer_range = config.initializer_range self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
@@ -511,6 +512,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -523,13 +525,6 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -584,16 +579,16 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
inputs["head_mask"], inputs["head_mask"],
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict, inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) pooled_output = self.pooler(sequence_output)
if not return_dict: if not inputs["return_dict"]:
return ( return (
sequence_output, sequence_output,
pooled_output, pooled_output,
@@ -739,6 +734,7 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -844,6 +840,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -857,7 +854,6 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta( outputs = self.roberta(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -867,7 +863,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
@@ -876,7 +872,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not inputs["return_dict"]:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -961,6 +957,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -974,7 +971,6 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta( outputs = self.roberta(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -984,16 +980,16 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=training) logits = self.classifier(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -1062,6 +1058,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1075,7 +1072,6 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(inputs["input_ids"])[1]
@@ -1103,7 +1099,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
inputs["inputs_embeds"], inputs["inputs_embeds"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
@@ -1113,7 +1109,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not inputs["return_dict"]:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -1175,6 +1171,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1188,7 +1185,6 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta( outputs = self.roberta(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -1198,18 +1194,17 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
@@ -1276,6 +1271,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1290,7 +1286,6 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta( outputs = self.roberta(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
@@ -1300,10 +1295,9 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
@@ -1317,7 +1311,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
labels["end_position"] = inputs["end_positions"] labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not inputs["return_dict"]:
output = (start_logits, end_logits) + outputs[2:] output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
@@ -547,6 +547,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def __init__(self, config, embed_tokens=None, **kwargs): def __init__(self, config, embed_tokens=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.use_cache = config.use_cache self.use_cache = config.use_cache
@@ -597,6 +599,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
) -> Tuple: ) -> Tuple:
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
@@ -610,13 +613,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
@@ -727,7 +723,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"]) hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])): for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
if output_hidden_states: if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module( layer_outputs = layer_module(
@@ -739,8 +735,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=inputs["head_mask"][i], head_mask=inputs["head_mask"][i],
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=inputs["use_cache"],
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
training=inputs["training"], training=inputs["training"],
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
@@ -756,23 +752,23 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# append next layer key value states # append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,) present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if inputs["output_attentions"]:
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# Add last layer # Add last layer
if output_hidden_states: if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,) outputs = (hidden_states,)
# need to check if is decoder here as well for special cases when using keras compile # need to check if is decoder here as well for special cases when using keras compile
if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder: if cast_bool_to_primitive(inputs["use_cache"], self.use_cache) is True and self.is_decoder:
outputs = outputs + (present_key_value_states,) outputs = outputs + (present_key_value_states,)
if output_hidden_states: if inputs["output_hidden_states"]:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if output_attentions: if inputs["output_attentions"]:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
@@ -1073,6 +1069,7 @@ class TFT5Model(TFT5PreTrainedModel):
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
@@ -1089,20 +1086,10 @@ class TFT5Model(TFT5PreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
# Encode if needed (training, first prediction pass) # Encode if needed (training, first prediction pass)
if encoder_outputs is None: if inputs["encoder_outputs"] is None:
encoder_outputs = self.encoder( inputs["encoder_outputs"] = self.encoder(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
encoder_hidden_states=None, encoder_hidden_states=None,
@@ -1111,12 +1098,12 @@ class TFT5Model(TFT5PreTrainedModel):
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
past_key_values=None, past_key_values=None,
use_cache=False, use_cache=False,
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=output_hidden_states, output_hidden_states=inputs["output_hidden_states"],
training=inputs["training"], training=inputs["training"],
) )
hidden_states = encoder_outputs[0] hidden_states = inputs["encoder_outputs"][0]
# Decode # Decode
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
@@ -1127,29 +1114,31 @@ class TFT5Model(TFT5PreTrainedModel):
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
use_cache=use_cache, use_cache=inputs["use_cache"],
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=output_hidden_states, output_hidden_states=inputs["output_hidden_states"],
training=inputs["training"], training=inputs["training"],
) )
past = ( past = (
(encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None (inputs["encoder_outputs"], decoder_outputs[1])
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
else None
) )
if not return_dict: if not inputs["return_dict"]:
if past is not None: if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + encoder_outputs return decoder_outputs + inputs["encoder_outputs"]
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch) # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
# TF refuses to compile anymore. # TF refuses to compile anymore.
if not cast_bool_to_primitive(use_cache, self.config.use_cache): if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:] decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
if not cast_bool_to_primitive(output_hidden_states, self.config.output_hidden_states): if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:] inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:] decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
if not cast_bool_to_primitive(output_attentions, self.config.output_attentions): if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
encoder_outputs = encoder_outputs + (None,) inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
decoder_outputs = decoder_outputs + (None,) decoder_outputs = decoder_outputs + (None,)
return TFSeq2SeqModelOutput( return TFSeq2SeqModelOutput(
@@ -1157,9 +1146,9 @@ class TFT5Model(TFT5PreTrainedModel):
past_key_values=past, past_key_values=past,
decoder_hidden_states=decoder_outputs[2], decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3], decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=encoder_outputs[0], encoder_last_hidden_state=inputs["encoder_outputs"][0],
encoder_hidden_states=encoder_outputs[1], encoder_hidden_states=inputs["encoder_outputs"][1],
encoder_attentions=encoder_outputs[2], encoder_attentions=inputs["encoder_outputs"][2],
) )
@@ -1261,6 +1250,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
@@ -1278,28 +1268,20 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
# Encode if needed (training, first prediction pass) # Encode if needed (training, first prediction pass)
if encoder_outputs is None: if inputs["encoder_outputs"] is None:
encoder_outputs = self.encoder( inputs["encoder_outputs"] = self.encoder(
inputs["input_ids"], inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=output_hidden_states, output_hidden_states=inputs["output_hidden_states"],
training=inputs["training"], training=inputs["training"],
) )
hidden_states = encoder_outputs[0] hidden_states = inputs["encoder_outputs"][0]
if ( if (
inputs["labels"] is not None inputs["labels"] is not None
@@ -1326,9 +1308,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=inputs["head_mask"], head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"], past_key_values=inputs["past_key_values"],
use_cache=use_cache, use_cache=inputs["use_cache"],
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=output_hidden_states, output_hidden_states=inputs["output_hidden_states"],
training=inputs["training"], training=inputs["training"],
) )
@@ -1344,29 +1326,33 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
past = ( past = (
(encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None (inputs["encoder_outputs"], decoder_outputs[1])
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
else None
) )
if not return_dict: if not inputs["return_dict"]:
if past is not None: if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (logits,) + decoder_outputs[1:] + encoder_outputs output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
# Putting this before breaks tf compilation. # Putting this before breaks tf compilation.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = (
output_attentions if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if inputs["output_hidden_states"] is not None else self.config.output_hidden_states
) )
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch) # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
# TF refuses to compile anymore. # TF refuses to compile anymore.
if not cast_bool_to_primitive(use_cache, self.config.use_cache): if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:] decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
if not cast_bool_to_primitive(output_hidden_states, self.config.output_hidden_states): if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:] inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:] decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
if not cast_bool_to_primitive(output_attentions, self.config.output_attentions): if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
encoder_outputs = encoder_outputs + (None,) inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
decoder_outputs = decoder_outputs + (None,) decoder_outputs = decoder_outputs + (None,)
return TFSeq2SeqLMOutput( return TFSeq2SeqLMOutput(
@@ -1375,9 +1361,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
past_key_values=past, past_key_values=past,
decoder_hidden_states=decoder_outputs[2], decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3], decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=encoder_outputs[0], encoder_last_hidden_state=inputs["encoder_outputs"][0],
encoder_hidden_states=encoder_outputs[1], encoder_hidden_states=inputs["encoder_outputs"][1],
encoder_attentions=encoder_outputs[2], encoder_attentions=inputs["encoder_outputs"][2],
) )
def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs): def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
......
@@ -16,7 +16,6 @@
""" """
TF 2.0 Transformer XL model. TF 2.0 Transformer XL model.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -384,6 +383,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.return_dict = config.use_return_dict self.return_dict = config.use_return_dict
@@ -516,6 +517,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
mems=mems, mems=mems,
head_mask=head_mask, head_mask=head_mask,
@@ -526,13 +528,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz] # so we transpose here from shape [bsz, len] to shape [len, bsz]
...@@ -591,7 +586,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -591,7 +586,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
# word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None] # word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
hids = [] hids = []
attentions = [] if output_attentions else None attentions = [] if inputs["output_attentions"] else None
if self.attn_type == 0: # default if self.attn_type == 0: # default
pos_seq = tf.range(klen - 1, -1, -1.0) pos_seq = tf.range(klen - 1, -1, -1.0)
if self.clamp_len > 0: if self.clamp_len > 0:
...@@ -610,11 +605,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -610,11 +605,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
dec_attn_mask, dec_attn_mask,
mems_i, mems_i,
inputs["head_mask"][i], inputs["head_mask"][i],
output_attentions, inputs["output_attentions"],
training=inputs["training"], training=inputs["training"],
) )
core_out = layer_outputs[0] core_out = layer_outputs[0]
if output_attentions: if inputs["output_attentions"]:
attentions.append(layer_outputs[1]) attentions.append(layer_outputs[1])
else: # learnable embeddings and absolute embeddings else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
...@@ -626,17 +621,17 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -626,17 +621,17 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
# We transpose back here to shape [bsz, len, hidden_dim] # We transpose back here to shape [bsz, len, hidden_dim]
core_out = tf.transpose(core_out, perm=(1, 0, 2)) core_out = tf.transpose(core_out, perm=(1, 0, 2))
if output_hidden_states: if inputs["output_hidden_states"]:
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim] # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids.append(core_out) hids.append(core_out)
hids = tuple(tf.transpose(t, perm=(1, 0, 2)) for t in hids) hids = tuple(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
else: else:
hids = None hids = None
if output_attentions: if inputs["output_attentions"]:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
if not return_dict: if not inputs["return_dict"]:
return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None) return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None)
return TFTransfoXLModelOutput( return TFTransfoXLModelOutput(
...@@ -824,6 +819,7 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel): ...@@ -824,6 +819,7 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
mems=mems, mems=mems,
head_mask=head_mask, head_mask=head_mask,
...@@ -921,6 +917,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -921,6 +917,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
mems=mems, mems=mems,
head_mask=head_mask, head_mask=head_mask,
...@@ -931,7 +928,6 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -931,7 +928,6 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
bsz, tgt_len = shape_list(inputs["input_ids"])[:2] bsz, tgt_len = shape_list(inputs["input_ids"])[:2]
...@@ -945,7 +941,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -945,7 +941,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
inputs["inputs_embeds"], inputs["inputs_embeds"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict, inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -954,7 +950,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -954,7 +950,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
softmax_output = self.crit(pred_hid, labels, training=inputs["training"]) softmax_output = self.crit(pred_hid, labels, training=inputs["training"])
if not return_dict: if not inputs["return_dict"]:
return (softmax_output,) + transformer_outputs[1:] return (softmax_output,) + transformer_outputs[1:]
return TFTransfoXLLMHeadModelOutput( return TFTransfoXLLMHeadModelOutput(
......
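The TransfoXL hunks show the pattern this commit applies everywhere: the layer keeps `self.config`, hands `config=self.config` to `input_processing`, and then reads `inputs["output_attentions"]`, `inputs["output_hidden_states"]` and `inputs["return_dict"]` directly, so the removed `x if x is not None else self.x` blocks no longer need to be repeated at the top of each `call`. A minimal, self-contained sketch of that fallback, assuming it now happens once during input processing (`resolve_booleans` and the namespace config are illustrative stand-ins, not the library helpers):

```python
# Sketch of the behaviour this commit centralises: booleans left as None by the
# caller are resolved against the model config once, during input processing,
# instead of at the top of every call(). SimpleNamespace stands in for a real config.
from types import SimpleNamespace

def resolve_booleans(config, output_attentions=None, output_hidden_states=None, return_dict=None):
    return {
        "output_attentions": output_attentions if output_attentions is not None else config.output_attentions,
        "output_hidden_states": output_hidden_states if output_hidden_states is not None else config.output_hidden_states,
        "return_dict": return_dict if return_dict is not None else config.return_dict,
    }

config = SimpleNamespace(output_attentions=False, output_hidden_states=True, return_dict=True)

# The caller overrides one flag; the others fall back to the config values.
inputs = resolve_booleans(config, output_attentions=True)
assert inputs == {"output_attentions": True, "output_hidden_states": True, "return_dict": True}

# Downstream code can then branch on inputs["..."] directly, as the hunks above do.
attentions = [] if inputs["output_attentions"] else None
```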
...@@ -230,6 +230,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -230,6 +230,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.return_dict = config.use_return_dict self.return_dict = config.use_return_dict
...@@ -361,6 +362,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -361,6 +362,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# removed: src_enc=None, src_len=None # removed: src_enc=None, src_len=None
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
...@@ -376,13 +378,6 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -376,13 +378,6 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
...@@ -473,11 +468,11 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -473,11 +468,11 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor = tensor * mask[..., tf.newaxis] tensor = tensor * mask[..., tf.newaxis]
# transformer layers # transformer layers
hidden_states = () if output_hidden_states else None hidden_states = () if inputs["output_hidden_states"] else None
attentions = () if output_attentions else None attentions = () if inputs["output_attentions"] else None
for i in range(self.n_layers): for i in range(self.n_layers):
if output_hidden_states: if inputs["output_hidden_states"]:
hidden_states = hidden_states + (tensor,) hidden_states = hidden_states + (tensor,)
# self attention # self attention
...@@ -487,12 +482,12 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -487,12 +482,12 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
None, None,
inputs["cache"], inputs["cache"],
inputs["head_mask"][i], inputs["head_mask"][i],
output_attentions, inputs["output_attentions"],
training=inputs["training"], training=inputs["training"],
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if output_attentions: if inputs["output_attentions"]:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=inputs["training"]) attn = self.dropout(attn, training=inputs["training"])
...@@ -512,7 +507,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -512,7 +507,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor = tensor * mask[..., tf.newaxis] tensor = tensor * mask[..., tf.newaxis]
# Add last hidden state # Add last hidden state
if output_hidden_states: if inputs["output_hidden_states"]:
hidden_states = hidden_states + (tensor,) hidden_states = hidden_states + (tensor,)
# update cache length # update cache length
...@@ -522,7 +517,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): ...@@ -522,7 +517,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# move back sequence length to dimension 0 # move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1) # tensor = tensor.transpose(0, 1)
if not return_dict: if not inputs["return_dict"]:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions) return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
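The XLM main-layer hunk keeps the same accumulation logic and only changes where the booleans come from: the `hidden_states` and `attentions` tuples are built per layer when the corresponding flag in `inputs` is on, and the tuple return drops the `None` entries when `return_dict` is off. A small runnable sketch of that accumulation, with strings standing in for real tensors and layers:

```python
# Sketch of the per-layer accumulation pattern in the hunk above: tuples are only
# built when the corresponding flag in `inputs` is on, and the non-dict return
# filters out the None entries. Layer outputs are faked with strings.
def run_layers(inputs, n_layers=3):
    tensor = "embeddings"
    hidden_states = () if inputs["output_hidden_states"] else None
    attentions = () if inputs["output_attentions"] else None
    for i in range(n_layers):
        if inputs["output_hidden_states"]:
            hidden_states = hidden_states + (tensor,)
        tensor, attn = f"layer_{i}_out", f"layer_{i}_attn"  # stand-ins for real layer outputs
        if inputs["output_attentions"]:
            attentions = attentions + (attn,)
    if inputs["output_hidden_states"]:
        hidden_states = hidden_states + (tensor,)           # add the last hidden state
    if not inputs["return_dict"]:
        return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
    return {"last_hidden_state": tensor, "hidden_states": hidden_states, "attentions": attentions}

out = run_layers({"output_hidden_states": True, "output_attentions": False, "return_dict": False})
assert out == ("layer_2_out", ("embeddings", "layer_0_out", "layer_1_out", "layer_2_out"))
```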
...@@ -720,6 +715,7 @@ class TFXLMModel(TFXLMPreTrainedModel): ...@@ -720,6 +715,7 @@ class TFXLMModel(TFXLMPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
...@@ -735,7 +731,6 @@ class TFXLMModel(TFXLMPreTrainedModel): ...@@ -735,7 +731,6 @@ class TFXLMModel(TFXLMPreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
outputs = self.transformer( outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -748,7 +743,7 @@ class TFXLMModel(TFXLMPreTrainedModel): ...@@ -748,7 +743,7 @@ class TFXLMModel(TFXLMPreTrainedModel):
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -848,6 +843,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): ...@@ -848,6 +843,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
...@@ -863,7 +859,6 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): ...@@ -863,7 +859,6 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -876,14 +871,14 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): ...@@ -876,14 +871,14 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
outputs = self.pred_layer(output) outputs = self.pred_layer(output)
if not return_dict: if not inputs["return_dict"]:
return (outputs,) + transformer_outputs[1:] return (outputs,) + transformer_outputs[1:]
return TFXLMWithLMHeadModelOutput( return TFXLMWithLMHeadModelOutput(
...@@ -939,6 +934,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -939,6 +934,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
...@@ -955,7 +951,6 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -955,7 +951,6 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -977,7 +972,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat ...@@ -977,7 +972,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1046,6 +1041,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1046,6 +1041,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
...@@ -1062,7 +1058,6 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1062,7 +1058,6 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(inputs["input_ids"])[1]
...@@ -1107,7 +1102,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1107,7 +1102,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
flat_inputs_embeds, flat_inputs_embeds,
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
...@@ -1117,7 +1112,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1117,7 +1112,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not inputs["return_dict"]:
output = (reshaped_logits,) + transformer_outputs[1:] output = (reshaped_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1180,6 +1175,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -1180,6 +1175,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
input_ids=input_ids, input_ids=input_ids,
config=self.config,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1195,7 +1191,6 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -1195,7 +1191,6 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -1208,10 +1203,9 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -1208,10 +1203,9 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"]) sequence_output = self.dropout(sequence_output, training=inputs["training"])
...@@ -1219,7 +1213,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos ...@@ -1219,7 +1213,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1284,6 +1278,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1284,6 +1278,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
...@@ -1301,7 +1296,6 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1301,7 +1296,6 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -1314,10 +1308,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1314,10 +1308,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
...@@ -1331,7 +1324,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL ...@@ -1331,7 +1324,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
labels["end_position"] = inputs["end_positions"] labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not inputs["return_dict"]:
output = (start_logits, end_logits) + transformer_outputs[1:] output = (start_logits, end_logits) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
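Across the XLM head models above, the only behavioural change is that `return_dict` is no longer re-derived from `self.transformer.return_dict`; each head branches on `inputs["return_dict"]` and, on the tuple path, prepends the loss when labels were given. A small self-contained sketch of that return convention (the dataclass and function names are illustrative, not the library classes):

```python
# Sketch of the head-model return convention visible above: with return_dict
# resolved upstream, the head either builds a plain tuple (loss first, if any)
# or a structured output object.
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class SketchSequenceClassifierOutput:
    loss: Optional[float]
    logits: str
    hidden_states: Optional[Tuple[str, ...]] = None

def head_call(transformer_outputs, logits, loss=None, return_dict=True):
    if not return_dict:
        output = (logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output
    return SketchSequenceClassifierOutput(loss=loss, logits=logits,
                                          hidden_states=transformer_outputs[1:] or None)

# Tuple path, no labels: the loss is simply omitted.
assert head_call(("last_hidden",), "logits", loss=None, return_dict=False) == ("logits",)
# Tuple path, with labels: the loss leads the tuple.
assert head_call(("last_hidden",), "logits", loss=0.5, return_dict=False) == (0.5, "logits")
```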
...@@ -419,6 +419,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -419,6 +419,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.return_dict = config.return_dict self.return_dict = config.return_dict
...@@ -590,6 +592,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -590,6 +592,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -606,18 +609,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -606,18 +609,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if training: if training and inputs["use_mems"] is None:
use_mems = use_mems if use_mems is not None else self.use_mems_train inputs["use_mems"] = self.use_mems_train
else: else:
use_mems = use_mems if use_mems is not None else self.use_mems_eval inputs["use_mems"] = self.use_mems_eval
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension # but we want a unified interface in the library with the batch size on the first dimension
...@@ -750,13 +746,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -750,13 +746,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if inputs["mems"] is None: if inputs["mems"] is None:
inputs["mems"] = [None] * len(self.layer) inputs["mems"] = [None] * len(self.layer)
attentions = [] if output_attentions else None attentions = [] if inputs["output_attentions"] else None
hidden_states = [] if output_hidden_states else None hidden_states = [] if inputs["output_hidden_states"] else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
# cache new mems # cache new mems
if use_mems: if inputs["use_mems"]:
new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),) new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
if output_hidden_states: if inputs["output_hidden_states"]:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module( outputs = layer_module(
...@@ -769,15 +765,15 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -769,15 +765,15 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
inputs["mems"][i], inputs["mems"][i],
inputs["target_mapping"], inputs["target_mapping"],
inputs["head_mask"][i], inputs["head_mask"][i],
output_attentions, inputs["output_attentions"],
training=inputs["training"], training=inputs["training"],
) )
output_h, output_g = outputs[:2] output_h, output_g = outputs[:2]
if output_attentions: if inputs["output_attentions"]:
attentions.append(outputs[2]) attentions.append(outputs[2])
# Add last hidden state # Add last hidden state
if output_hidden_states: if inputs["output_hidden_states"]:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"]) output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
...@@ -785,17 +781,17 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -785,17 +781,17 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = tf.transpose(output, perm=(1, 0, 2)) output = tf.transpose(output, perm=(1, 0, 2))
if not use_mems: if not inputs["use_mems"]:
new_mems = None new_mems = None
if output_hidden_states: if inputs["output_hidden_states"]:
if output_g is not None: if output_g is not None:
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs) hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
else: else:
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states) hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
if output_attentions: if inputs["output_attentions"]:
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
if not return_dict: if not inputs["return_dict"]:
return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None) return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
return TFXLNetModelOutput( return TFXLNetModelOutput(
...@@ -1173,6 +1169,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel): ...@@ -1173,6 +1169,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
): ):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1317,6 +1314,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1317,6 +1314,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1334,7 +1332,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1334,7 +1332,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -1348,7 +1345,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1348,7 +1345,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
use_mems=inputs["use_mems"], use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
hidden_state = transformer_outputs[0] hidden_state = transformer_outputs[0]
...@@ -1361,7 +1358,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1361,7 +1358,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
labels = inputs["labels"][:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1428,6 +1425,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1428,6 +1425,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1445,7 +1443,6 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1445,7 +1443,6 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -1469,7 +1466,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif ...@@ -1469,7 +1466,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1546,6 +1543,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1546,6 +1543,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1563,7 +1561,6 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1563,7 +1561,6 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(inputs["input_ids"])[1]
...@@ -1600,7 +1597,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1600,7 +1597,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
inputs["use_mems"], inputs["use_mems"],
inputs["output_attentions"], inputs["output_attentions"],
inputs["output_hidden_states"], inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
...@@ -1609,7 +1606,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1609,7 +1606,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not inputs["return_dict"]:
output = (reshaped_logits,) + transformer_outputs[1:] output = (reshaped_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1672,6 +1669,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1672,6 +1669,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1689,7 +1687,6 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1689,7 +1687,6 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -1703,7 +1700,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1703,7 +1700,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
use_mems=inputs["use_mems"], use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
...@@ -1711,7 +1708,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio ...@@ -1711,7 +1708,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
logits = self.classifier(output) logits = self.classifier(output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1778,6 +1775,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1778,6 +1775,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
""" """
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
...@@ -1796,7 +1794,6 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1796,7 +1794,6 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
training=training, training=training,
kwargs_call=kwargs, kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids=inputs["input_ids"], input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"], attention_mask=inputs["attention_mask"],
...@@ -1810,10 +1807,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1810,10 +1807,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
use_mems=inputs["use_mems"], use_mems=inputs["use_mems"],
output_attentions=inputs["output_attentions"], output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=inputs["return_dict"],
training=inputs["training"], training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
...@@ -1827,7 +1823,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer ...@@ -1827,7 +1823,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
labels["end_position"] = inputs["end_positions"] labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not inputs["return_dict"]:
output = (start_logits, end_logits) + transformer_outputs[1:] output = (start_logits, end_logits) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
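The XLNet main-layer hunk above resolves `use_mems` to a mode-specific config default when the caller leaves it unset (`use_mems_train` during training, `use_mems_eval` otherwise). A minimal sketch of that default resolution, with a fake namespace standing in for the real config; only the two attribute names are taken from the hunk, the rest is illustrative:

```python
# Sketch of the use_mems default resolution for XLNet: an explicit caller value
# wins, otherwise the training-time or evaluation-time default from the config applies.
from types import SimpleNamespace

def resolve_use_mems(config, use_mems=None, training=False):
    if use_mems is not None:
        return use_mems                                   # explicit caller choice wins
    return config.use_mems_train if training else config.use_mems_eval

config = SimpleNamespace(use_mems_train=False, use_mems_eval=True)  # example defaults only
assert resolve_use_mems(config, training=True) is False             # training default
assert resolve_use_mems(config, training=False) is True             # eval default
assert resolve_use_mems(config, use_mems=True, training=True) is True
```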