Unverified Commit c2f8eaf6 authored by Joao Gante, committed by GitHub

TF: unpack inputs on Convbert, GPTJ, LED, and templates (#16491)

* Add unpack_inputs to remaining models

* remove stray use of inputs in the templates; fix tf.debugging of attn masks
parent ae189ef9
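The pattern applied throughout this diff replaces the per-call `inputs = input_processing(...)` helper, and the `inputs["..."]` dict lookups that follow it, with the `@unpack_inputs` decorator, so each `call` body reads its keyword arguments directly. The tf.debugging change likewise puts the mask's name into the assertion message rather than the tensor itself. A minimal sketch of the decorator idea follows, using a hypothetical `simple_unpack_inputs` and `TinyLayer` for illustration only; this is not the actual `unpack_inputs` implementation from `modeling_tf_utils`.

    import functools

    def simple_unpack_inputs(call):
        # Hypothetical stand-in: expand a leading dict of model inputs into keyword
        # arguments, so the wrapped `call` body can use plain named parameters
        # instead of inputs["..."] lookups.
        @functools.wraps(call)
        def wrapper(self, *args, **kwargs):
            if args and isinstance(args[0], dict):
                kwargs = {**args[0], **kwargs}
                args = args[1:]
            return call(self, *args, **kwargs)

        return wrapper

    class TinyLayer:
        @simple_unpack_inputs
        def call(self, input_ids=None, attention_mask=None, training=False):
            # With the decorator, the body uses the arguments directly.
            return {"input_ids": input_ids, "attention_mask": attention_mask, "training": training}

    # Both call styles resolve to the same keyword arguments inside `call`:
    layer = TinyLayer()
    print(layer.call({"input_ids": [1, 2, 3]}, training=True))
    print(layer.call(input_ids=[1, 2, 3], training=True))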
@@ -943,12 +943,12 @@ class TFBartDecoder(tf.keras.layers.Layer):
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in [head_mask, cross_attn_head_mask]:
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
             if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
...
@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
     TFSequenceSummary,
     TFTokenClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import (
@@ -568,6 +568,7 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
         return head_mask

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -582,60 +583,36 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(input_shape, 1)
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)

-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.fill(input_shape, 0)
+        if token_type_ids is None:
+            token_type_ids = tf.fill(input_shape, 0)

-        hidden_states = self.embeddings(
-            inputs["input_ids"],
-            inputs["position_ids"],
-            inputs["token_type_ids"],
-            inputs["inputs_embeds"],
-            training=inputs["training"],
-        )
-        extended_attention_mask = self.get_extended_attention_mask(
-            inputs["attention_mask"], input_shape, hidden_states.dtype
-        )
-        inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
+        hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
+        head_mask = self.get_head_mask(head_mask)

         if hasattr(self, "embeddings_project"):
-            hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
+            hidden_states = self.embeddings_project(hidden_states, training=training)

         hidden_states = self.encoder(
             hidden_states,
             extended_attention_mask,
-            inputs["head_mask"],
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            inputs["return_dict"],
-            training=inputs["training"],
+            head_mask,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+            training=training,
         )

         return hidden_states
@@ -754,6 +731,7 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
         self.convbert = TFConvBertMainLayer(config, name="convbert")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -775,9 +753,7 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.convbert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -788,19 +764,6 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.convbert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -886,6 +849,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
     def get_prefix_bias_name(self):
         return self.name + "/" + self.generator_lm_head.name

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -914,9 +878,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        generator_hidden_states = self.convbert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -926,28 +888,14 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        generator_hidden_states = self.convbert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         generator_sequence_output = generator_hidden_states[0]
-        prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
-        prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
+        prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
+        prediction_scores = self.generator_lm_head(prediction_scores, training=training)
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (prediction_scores,) + generator_hidden_states[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1010,6 +958,7 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
         self.convbert = TFConvBertMainLayer(config, name="convbert")
         self.classifier = TFConvBertClassificationHead(config, name="classifier")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1038,10 +987,8 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.convbert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1050,26 +997,12 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.convbert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
-        logits = self.classifier(outputs[0], training=inputs["training"])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        logits = self.classifier(outputs[0], training=training)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1117,6 +1050,7 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
         """
         return {"input_ids": tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)}

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(
         CONVBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
     )
@@ -1146,43 +1080,20 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
             where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None:
-            num_choices = shape_list(inputs["input_ids"])[1]
-            seq_length = shape_list(inputs["input_ids"])[2]
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = shape_list(inputs["inputs_embeds"])[1]
-            seq_length = shape_list(inputs["inputs_embeds"])[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]

-        flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
-        flat_attention_mask = (
-            tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
-        )
-        flat_token_type_ids = (
-            tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
-        )
-        flat_position_ids = (
-            tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
-        )
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
         flat_inputs_embeds = (
-            tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
-            if inputs["inputs_embeds"] is not None
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
             else None
         )
         outputs = self.convbert(
@@ -1190,19 +1101,19 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
             flat_attention_mask,
             flat_token_type_ids,
             flat_position_ids,
-            inputs["head_mask"],
+            head_mask,
             flat_inputs_embeds,
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions,
+            output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
-        logits = self.sequence_summary(outputs[0], training=inputs["training"])
+        logits = self.sequence_summary(outputs[0], training=training)
         logits = self.classifier(logits)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (reshaped_logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1256,6 +1167,7 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1282,10 +1194,8 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.convbert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1294,28 +1204,14 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.convbert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]

-        sequence_output = self.dropout(sequence_output, training=inputs["training"])
+        sequence_output = self.dropout(sequence_output, training=training)
         logits = self.classifier(sequence_output)
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1350,6 +1246,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1383,10 +1280,8 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.convbert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1395,22 +1290,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            start_positions=start_positions,
-            end_positions=end_positions,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.convbert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]

         logits = self.qa_outputs(sequence_output)
@@ -1419,12 +1299,12 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
         end_logits = tf.squeeze(end_logits, axis=-1)
         loss = None

-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
             loss = self.hf_compute_loss(labels, (start_logits, end_logits))

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
@@ -40,7 +40,6 @@ from ...modeling_tf_utils import (
     TFSequenceClassificationLoss,
     TFSharedEmbeddings,
     get_initializer,
-    input_processing,
     keras_serializable,
     unpack_inputs,
 )
@@ -376,6 +375,7 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -392,53 +392,34 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-            inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["past_key_values"] is None:
+        if past_key_values is None:
             past_length = 0
-            inputs["past_key_values"] = [None] * len(self.h)
+            past_key_values = [None] * len(self.h)
         else:
-            past_length = shape_list(inputs["past_key_values"][0][0])[-2]
+            past_length = shape_list(past_key_values[0][0])[-2]

-        if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
+        if position_ids is None:
+            position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)

-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # We create a 3D attention mask from a 2D tensor mask.
             # Sizes are [batch_size, 1, 1, to_seq_length]
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
             # this attention mask is more simple than the triangular masking of causal attention
             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask_shape = shape_list(inputs["attention_mask"])
-            inputs["attention_mask"] = tf.reshape(
-                inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
-            )
+            attention_mask_shape = shape_list(attention_mask)
+            attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]))

             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
@@ -446,78 +427,74 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             one_cst = tf.constant(1.0)
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
-            inputs["attention_mask"] = tf.multiply(
-                tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
-            )
+            attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
+            attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.num_hidden_layers
+            head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)

-        inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
+        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids, mode="embedding")

-        if inputs["token_type_ids"] is not None:
-            inputs["token_type_ids"] = tf.reshape(
-                inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
-            )
-            token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
+        if token_type_ids is not None:
+            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
+            token_type_embeds = self.wte(token_type_ids, mode="embedding")
         else:
             token_type_embeds = tf.constant(0.0)

-        token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
-        hidden_states = inputs["inputs_embeds"] + token_type_embeds
-        hidden_states = self.drop(hidden_states, training=inputs["training"])
+        token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
+        hidden_states = inputs_embeds + token_type_embeds
+        hidden_states = self.drop(hidden_states, training=training)

         output_shape = input_shape + [shape_list(hidden_states)[-1]]

-        presents = () if inputs["use_cache"] else None
-        all_attentions = () if inputs["output_attentions"] else None
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        for i, (block, layer_past) in enumerate(zip(self.h, inputs["past_key_values"])):
-            if inputs["output_hidden_states"]:
+        presents = () if use_cache else None
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            if output_hidden_states:
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

             outputs = block(
                 hidden_states,
                 layer_past,
-                inputs["attention_mask"],
-                inputs["head_mask"][i],
-                inputs["use_cache"],
-                inputs["output_attentions"],
-                training=inputs["training"],
+                attention_mask,
+                head_mask[i],
+                use_cache,
+                output_attentions,
+                training=training,
             )

             hidden_states = outputs[0]
-            if inputs["use_cache"]:
+            if use_cache:
                 presents = presents + (outputs[1],)

-            if inputs["output_attentions"]:
-                all_attentions = all_attentions + (outputs[2 if inputs["use_cache"] else 1],)
+            if output_attentions:
+                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)

         hidden_states = self.ln_f(hidden_states)

         hidden_states = tf.reshape(hidden_states, output_shape)
         # Add last hidden state
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        if inputs["output_attentions"]:
+        if output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
             all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

         return TFBaseModelOutputWithPast(
...
@@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import (
@@ -1654,6 +1654,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1703,95 +1704,74 @@ class TFLEDEncoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            global_attention_mask=global_attention_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(input_shape, 1)
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)

         # merge `global_attention_mask` and `attention_mask`
-        if inputs["global_attention_mask"] is not None:
-            inputs["attention_mask"] = inputs["attention_mask"] * tf.cast(
-                (inputs["global_attention_mask"] + 1), dtype=inputs["attention_mask"].dtype
-            )
+        if global_attention_mask is not None:
+            attention_mask = attention_mask * tf.cast((global_attention_mask + 1), dtype=attention_mask.dtype)

-        (
-            padding_len,
-            inputs["input_ids"],
-            inputs["attention_mask"],
-            inputs["inputs_embeds"],
-        ) = self._pad_to_window_size(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
+        (padding_len, input_ids, attention_mask, inputs_embeds,) = self._pad_to_window_size(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             pad_token_id=self.padding_idx,
         )

-        input_shape = shape_list(inputs["attention_mask"])
+        input_shape = shape_list(attention_mask)

         # is index masked or global attention
-        is_index_masked = tf.math.less(tf.cast(inputs["attention_mask"], tf.int8), 1)
-        is_index_global_attn = tf.math.greater(tf.cast(inputs["attention_mask"], tf.int8), 1)
+        is_index_masked = tf.math.less(tf.cast(attention_mask, tf.int8), 1)
+        is_index_global_attn = tf.math.greater(tf.cast(attention_mask, tf.int8), 1)
         is_global_attn = tf.math.reduce_any(is_index_global_attn)

         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs["inputs_embeds"] + embed_pos
+        hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # check attention mask and invert
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])[:, 0, 0, :]
-            inputs["attention_mask"] = inputs["attention_mask"][:, :, None, None]
+            attention_mask = _expand_mask(attention_mask)[:, 0, 0, :]
+            attention_mask = attention_mask[:, :, None, None]

-        encoder_states = () if inputs["output_hidden_states"] else None
-        all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
+        encoder_states = () if output_hidden_states else None
+        all_attentions = all_global_attentions = () if output_attentions else None

         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )

         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):

-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
                 encoder_states = encoder_states + (hidden_states_to_add,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
                 continue

             layer_outputs = encoder_layer(
                 hidden_states=hidden_states,
-                attention_mask=inputs["attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                attention_mask=attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
                 is_index_masked=is_index_masked,
                 is_index_global_attn=is_index_global_attn,
                 is_global_attn=is_global_attn,
@@ -1799,7 +1779,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
             hidden_states = layer_outputs[0]

-            if inputs["output_attentions"]:
+            if output_attentions:
                 # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
                 all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
@@ -1811,17 +1791,17 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         hidden_states = self.compute_hidden_states(hidden_states, padding_len)

         # undo padding
-        if inputs["output_attentions"]:
+        if output_attentions:
             all_attentions = (
                 tuple([state[:, :, :-padding_len, :] for state in all_attentions])
                 if padding_len > 0
                 else all_attentions
             )

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return TFLEDEncoderBaseModelOutput(
             last_hidden_state=hidden_states,
@@ -1915,6 +1895,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1985,45 +1966,25 @@ class TFLEDDecoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            encoder_head_mask=encoder_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
-        )
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)

-        hidden_states = inputs["inputs_embeds"]
+        hidden_states = inputs_embeds

         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if input_shape[-1] > 1:
@@ -2033,17 +1994,15 @@ class TFLEDDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )

-        if inputs["attention_mask"] is not None and input_shape[-1] > 1:
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                inputs["attention_mask"], tgt_len=input_shape[-1]
-            )
+        if attention_mask is not None and input_shape[-1] > 1:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])

-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])

         hidden_states = self.layernorm_embedding(hidden_states + positions)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # decoder layers
         all_hidden_states = ()
@@ -2052,54 +2011,52 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         present_key_values = ()

         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )

         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             dropout_probability = random.uniform(0, 1)

-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue

-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
-                if inputs["encoder_head_mask"] is not None
-                else None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
                 past_key_value=past_key_value,
             )

-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)
                 all_cross_attentions += (layer_cross_attn,)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)
         else:
             all_hidden_states = None

-        all_self_attns = all_self_attns if inputs["output_attentions"] else None
-        all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
+        all_self_attns = all_self_attns if output_attentions else None
+        all_cross_attentions = all_cross_attentions if output_attentions else None

-        present_key_values = present_key_values if inputs["use_cache"] else None
+        present_key_values = present_key_values if use_cache else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(
                 v
                 for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
...@@ -2149,6 +2106,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer): ...@@ -2149,6 +2106,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
self.encoder.set_embed_tokens(embed_tokens) self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens) self.decoder.set_embed_tokens(embed_tokens)
@unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -2169,72 +2127,51 @@ class TFLEDMainLayer(tf.keras.layers.Layer): ...@@ -2169,72 +2127,51 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
training=False, training=False,
**kwargs **kwargs
): ):
inputs = input_processing(
func=self.call, if decoder_input_ids is None and decoder_inputs_embeds is None:
config=self.config, use_cache = False
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
past_key_values=past_key_values, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
inputs["use_cache"] = False
if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
# If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True # If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True
elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFLEDEncoderBaseModelOutput): elif return_dict and not isinstance(encoder_outputs, TFLEDEncoderBaseModelOutput):
inputs["encoder_outputs"] = TFLEDEncoderBaseModelOutput( encoder_outputs = TFLEDEncoderBaseModelOutput(
last_hidden_state=inputs["encoder_outputs"][0], last_hidden_state=encoder_outputs[0],
hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None, hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
) )
# If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False # If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple): elif not return_dict and not isinstance(encoder_outputs, tuple):
inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple() encoder_outputs = encoder_outputs.to_tuple()
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
inputs["decoder_input_ids"], decoder_input_ids,
attention_mask=inputs["decoder_attention_mask"], attention_mask=decoder_attention_mask,
encoder_hidden_states=inputs["encoder_outputs"][0], encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=inputs["attention_mask"], encoder_attention_mask=attention_mask,
head_mask=inputs["decoder_head_mask"], head_mask=decoder_head_mask,
encoder_head_mask=inputs["head_mask"], encoder_head_mask=head_mask,
past_key_values=inputs["past_key_values"], past_key_values=past_key_values,
inputs_embeds=inputs["decoder_inputs_embeds"], inputs_embeds=decoder_inputs_embeds,
use_cache=inputs["use_cache"], use_cache=use_cache,
output_attentions=inputs["output_attentions"], output_attentions=output_attentions,
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
if not inputs["return_dict"]: if not return_dict:
return decoder_outputs + inputs["encoder_outputs"] return decoder_outputs + encoder_outputs
return TFLEDSeq2SeqModelOutput( return TFLEDSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state, last_hidden_state=decoder_outputs.last_hidden_state,
...@@ -2242,10 +2179,10 @@ class TFLEDMainLayer(tf.keras.layers.Layer): ...@@ -2242,10 +2179,10 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions, cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states, encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions, encoder_attentions=encoder_outputs.attentions,
encoder_global_attentions=inputs["encoder_outputs"].global_attentions, encoder_global_attentions=encoder_outputs.global_attentions,
) )
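The changes in this file swap the explicit `input_processing` boilerplate for the `@unpack_inputs` decorator. As a rough illustration only (the real helper in `modeling_tf_utils` handles more input formats than this), an unpack-inputs style decorator can be sketched as:

```python
import functools


def unpack_inputs_sketch(call_fn):
    # If the layer is called Keras-style with a single dict of inputs, spread that
    # dict into keyword arguments so the body of `call` can use plain names such as
    # attention_mask or return_dict instead of inputs["..."] lookups.
    @functools.wraps(call_fn)
    def wrapper(self, *args, **kwargs):
        if args and isinstance(args[0], dict):
            merged = {**args[0], **kwargs}  # explicit keyword arguments take precedence
            return call_fn(self, **merged)
        return call_fn(self, *args, **kwargs)

    return wrapper


class ToyLayer:
    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, training=False):
        return input_ids, attention_mask, training


ToyLayer().call({"input_ids": [1, 2], "attention_mask": [1, 1]}, training=True)
# -> ([1, 2], [1, 1], True)
```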
...@@ -2265,6 +2202,7 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -2265,6 +2202,7 @@ class TFLEDModel(TFLEDPreTrainedModel):
def get_decoder(self): def get_decoder(self):
return self.led.decoder return self.led.decoder
@unpack_inputs
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -2292,17 +2230,16 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -2292,17 +2230,16 @@ class TFLEDModel(TFLEDPreTrainedModel):
training=False, training=False,
**kwargs **kwargs
): ):
inputs = input_processing(
func=self.call, outputs = self.led(
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -2311,25 +2248,6 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -2311,25 +2248,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.led(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
global_attention_mask=inputs["global_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
return outputs return outputs
...@@ -2393,6 +2311,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2393,6 +2311,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
def set_output_embeddings(self, value): def set_output_embeddings(self, value):
self.set_input_embeddings(value) self.set_input_embeddings(value)
@unpack_inputs
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
...@@ -2435,17 +2354,22 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2435,17 +2354,22 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
>>> # probs[5] is associated with the mask token >>> # probs[5] is associated with the mask token
```""" ```"""
inputs = input_processing( if labels is not None:
func=self.call, use_cache = False
config=self.config, if decoder_input_ids is None:
input_ids=input_ids, decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.led(
input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds,
...@@ -2453,41 +2377,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2453,41 +2377,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
if inputs["labels"] is not None:
inputs["use_cache"] = False
if inputs["decoder_input_ids"] is None:
inputs["decoder_input_ids"] = shift_tokens_right(
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.led(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
global_attention_mask=inputs["global_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
lm_logits = self.led.shared(outputs[0], mode="linear") lm_logits = self.led.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias lm_logits = lm_logits + self.final_logits_bias
masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits) masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
if not inputs["return_dict"]: if not return_dict:
output = (lm_logits,) + outputs[1:] output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return TFLEDSeq2SeqLMOutput( return TFLEDSeq2SeqLMOutput(
......
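In the conditional-generation `call` above, missing `decoder_input_ids` are derived from `labels` through `shift_tokens_right`. A standalone sketch of that behaviour, assuming the usual seq2seq convention (prepend the decoder start token, drop the last label, and map the -100 ignore index back to padding):

```python
import tensorflow as tf


def shift_tokens_right_sketch(labels, pad_token_id, decoder_start_token_id):
    # Replace the -100 ignore index by the pad token, then shift right by one
    # position and prepend the decoder start token; the last label is dropped.
    labels = tf.where(labels == -100, tf.fill(tf.shape(labels), pad_token_id), labels)
    start_tokens = tf.fill(tf.stack([tf.shape(labels)[0], 1]), decoder_start_token_id)
    return tf.concat([start_tokens, labels[:, :-1]], axis=-1)


shift_tokens_right_sketch(
    tf.constant([[5, 6, 7, -100]]), pad_token_id=1, decoder_start_token_id=2
)
# -> [[2, 5, 6, 7]]
```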
...@@ -965,12 +965,12 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -965,12 +965,12 @@ class TFMBartDecoder(tf.keras.layers.Layer):
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they # The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager. # have to be disabled in modes other than eager.
for attn_mask in [head_mask, cross_attn_head_mask]: for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
......
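The `tf.executing_eagerly()` guard that keeps reappearing in these hunks exists because the shape asserts cannot run inside compiled graphs; a minimal illustration of the pattern:

```python
import tensorflow as tf


def check_head_mask(head_mask, num_layers):
    # The shape check only runs eagerly; during tf.function/XLA tracing,
    # tf.executing_eagerly() is False and the assert is skipped entirely.
    if head_mask is not None and tf.executing_eagerly():
        tf.debugging.assert_equal(
            tf.shape(head_mask)[0],
            num_layers,
            message=f"The head_mask should be specified for {num_layers} layers.",
        )


check_head_mask(tf.ones((12, 16)), num_layers=12)    # eager: the assert runs and passes
tf.function(check_head_mask)(tf.ones((12, 16)), 12)  # traced: the guard is False, assert skipped
```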
...@@ -1060,12 +1060,12 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer): ...@@ -1060,12 +1060,12 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they have to be disabled in modes other than eager. # The tf.debugging asserts are not compliant with XLA, so they have to be disabled in modes other than eager.
for attn_mask in [head_mask, cross_attn_head_mask]: for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly(): if attn_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(attn_mask)[0], shape_list(attn_mask)[0],
len(self.layers), len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
) )
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
......
...@@ -50,8 +50,8 @@ from ...modeling_tf_utils import ( ...@@ -50,8 +50,8 @@ from ...modeling_tf_utils import (
TFSequenceSummary, TFSequenceSummary,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
unpack_inputs,
) )
from ...tf_utils import shape_list from ...tf_utils import shape_list
from ...utils import logging from ...utils import logging
...@@ -636,6 +636,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -636,6 +636,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
""" """
raise NotImplementedError raise NotImplementedError
@unpack_inputs
def call( def call(
self, self,
input_ids: Optional[TFModelInputType] = None, input_ids: Optional[TFModelInputType] = None,
...@@ -654,59 +655,40 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -654,59 +655,40 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
training: bool = False, training: bool = False,
**kwargs, **kwargs,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if not self.config.is_decoder: if not self.config.is_decoder:
inputs["use_cache"] = False use_cache = False
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None: elif input_ids is not None:
input_shape = shape_list(inputs["input_ids"]) input_shape = shape_list(input_ids)
elif inputs["inputs_embeds"] is not None: elif inputs_embeds is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1] input_shape = shape_list(inputs_embeds)[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
if inputs["past_key_values"] is None: if past_key_values is None:
past_key_values_length = 0 past_key_values_length = 0
inputs["past_key_values"] = [None] * len(self.encoder.layer) past_key_values = [None] * len(self.encoder.layer)
else: else:
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] past_key_values_length = shape_list(past_key_values[0][0])[-2]
if inputs["attention_mask"] is None: if attention_mask is None:
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
if inputs["token_type_ids"] is None: if token_type_ids is None:
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) token_type_ids = tf.fill(dims=input_shape, value=0)
embedding_output = self.embeddings( embedding_output = self.embeddings(
input_ids=inputs["input_ids"], input_ids=input_ids,
position_ids=inputs["position_ids"], position_ids=position_ids,
token_type_ids=inputs["token_type_ids"], token_type_ids=token_type_ids,
inputs_embeds=inputs["inputs_embeds"], inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length, past_key_values_length=past_key_values_length,
training=inputs["training"], training=training,
) )
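When a cache is present, the default masks built above have to cover both the cached positions and the new tokens; a toy example of the lengths involved (the cache layout here is the standard (batch, num_heads, cached_len, head_dim) assumption):

```python
import tensorflow as tf

# Toy cache for one layer: key/value tensors of shape (batch, num_heads, cached_len, head_dim).
past_key_values = [(tf.zeros((2, 4, 5, 8)), tf.zeros((2, 4, 5, 8)))]
past_key_values_length = past_key_values[0][0].shape[-2]            # 5 cached positions

batch_size, seq_length = 2, 3                                        # 3 new tokens this step
attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
attention_mask.shape  # (2, 8): the new tokens may attend to the cache and to themselves
```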
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
...@@ -714,7 +696,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -714,7 +696,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is simpler than the triangular masking of causal attention # this attention mask is simpler than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask_shape = shape_list(inputs["attention_mask"]) attention_mask_shape = shape_list(attention_mask)
mask_seq_length = seq_length + past_key_values_length mask_seq_length = seq_length + past_key_values_length
# Copied from `modeling_tf_t5.py` # Copied from `modeling_tf_t5.py`
...@@ -727,18 +709,18 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -727,18 +709,18 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None], seq_ids[None, :, None],
) )
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] extended_attention_mask = causal_mask * attention_mask[:, None, :]
attention_mask_shape = shape_list(extended_attention_mask) attention_mask_shape = shape_list(extended_attention_mask)
extended_attention_mask = tf.reshape( extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
) )
if inputs["past_key_values"][0] is not None: if past_key_values[0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]` # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]`
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else: else:
extended_attention_mask = tf.reshape( extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
) )
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
...@@ -752,18 +734,18 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -752,18 +734,18 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
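Numerically, the block above first builds a lower-triangular causal pattern from `seq_ids`, combines it with the padding mask, and finally turns the 1/0 keep-mask into an additive bias; for instance:

```python
import tensorflow as tf

seq_ids = tf.range(3)
causal_mask = tf.cast(seq_ids[None, None, :] <= seq_ids[None, :, None], tf.float32)
# causal_mask[0] == [[1., 0., 0.],
#                    [1., 1., 0.],
#                    [1., 1., 1.]]   (position i may attend to positions <= i)

attention_mask = tf.constant([[1.0, 1.0, 0.0]])       # last position is padding
combined = causal_mask * attention_mask[:, None, :]   # (1, 3, 3) keep-mask
additive = (1.0 - combined) * -10000.0                # 0 where attending, -10000 elsewhere
```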
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000 # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
if self.is_decoder and inputs["encoder_attention_mask"] is not None: if self.is_decoder and encoder_attention_mask is not None:
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs["encoder_attention_mask"] = tf.cast( encoder_attention_mask = tf.cast(
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype encoder_attention_mask, dtype=extended_attention_mask.dtype
) )
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
if num_dims_encoder_attention_mask == 3: if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if num_dims_encoder_attention_mask == 2: if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
...@@ -779,28 +761,28 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer): ...@@ -779,28 +761,28 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None: if head_mask is not None:
raise NotImplementedError raise NotImplementedError
else: else:
inputs["head_mask"] = [None] * self.config.num_hidden_layers head_mask = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
hidden_states=embedding_output, hidden_states=embedding_output,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"], head_mask=head_mask,
encoder_hidden_states=inputs["encoder_hidden_states"], encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=inputs["past_key_values"], past_key_values=past_key_values,
use_cache=inputs["use_cache"], use_cache=use_cache,
output_attentions=inputs["output_attentions"], output_attentions=output_attentions,
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
if not inputs["return_dict"]: if not return_dict:
return ( return (
sequence_output, sequence_output,
) + encoder_outputs[1:] ) + encoder_outputs[1:]
...@@ -943,6 +925,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -943,6 +925,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}") self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -988,9 +971,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -988,9 +971,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
If set to `True`, `past_key_values` key value states are returned and can be used to speed up If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`). Set to `False` during training, `True` during generation decoding (see `past_key_values`). Set to `False` during training, `True` during generation
""" """
inputs = input_processing( outputs = self.{{cookiecutter.lowercase_modelname}}(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1005,23 +986,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod ...@@ -1005,23 +986,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
return outputs return outputs
...@@ -1064,6 +1028,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -1064,6 +1028,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
def get_lm_head(self) -> tf.keras.layers.Layer: def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions return self.mlm.predictions
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1091,9 +1056,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -1091,9 +1056,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
""" """
inputs = input_processing( outputs = self.{{cookiecutter.lowercase_modelname}}(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1103,29 +1066,15 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca ...@@ -1103,29 +1066,15 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"]) prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
loss = ( loss = (
None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores) None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
) )
if not inputs["return_dict"]: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1173,6 +1122,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca ...@@ -1173,6 +1122,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
"use_cache": model_kwargs["use_cache"], "use_cache": model_kwargs["use_cache"],
} }
@unpack_inputs
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC,
...@@ -1220,9 +1170,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca ...@@ -1220,9 +1170,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`. Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`.
""" """
inputs = input_processing( outputs = self.{{cookiecutter.lowercase_modelname}}(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1236,37 +1184,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca ...@@ -1236,37 +1184,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.mlm(sequence_output=sequence_output, training=inputs["training"]) logits = self.mlm(sequence_output=sequence_output, training=training)
loss = None loss = None
if inputs["labels"] is not None: if labels is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
shifted_logits = logits[:, :-1] shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:] labels = labels[:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
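The slicing above aligns each position's logits with the following token before the loss is computed; a toy version of that alignment (using Keras's sparse crossentropy purely for illustration, not the model's `hf_compute_loss`):

```python
import tensorflow as tf

logits = tf.random.uniform((1, 4, 32))     # (batch, seq_len, vocab): the logit at t predicts token t + 1
labels = tf.constant([[11, 12, 13, 14]])   # (batch, seq_len)

shifted_logits = logits[:, :-1]            # drop the last position, which has nothing left to predict
shifted_labels = labels[:, 1:]             # drop the first token, which has no predictor
loss = tf.keras.losses.sparse_categorical_crossentropy(
    shifted_labels, shifted_logits, from_logits=True
)  # shape (1, 3): one loss term per predicted token
```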
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1338,6 +1268,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie ...@@ -1338,6 +1268,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}") self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
self.classifier = TF{{cookiecutter.camelcase_modelname}}ClassificationHead(config, name="classifier") self.classifier = TF{{cookiecutter.camelcase_modelname}}ClassificationHead(config, name="classifier")
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1365,9 +1296,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie ...@@ -1365,9 +1296,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
inputs = input_processing( outputs = self.{{cookiecutter.lowercase_modelname}}(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1377,26 +1306,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie ...@@ -1377,26 +1306,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
) )
outputs = self.{{cookiecutter.lowercase_modelname}}( logits = self.classifier(hidden_states=outputs[0], training=training)
input_ids=inputs["input_ids"], loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"], if not return_dict:
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
logits = self.classifier(hidden_states=outputs[0], training=inputs["training"])
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1443,6 +1358,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c ...@@ -1443,6 +1358,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
""" """
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1470,53 +1386,37 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c ...@@ -1470,53 +1386,37 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. (See Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above) `input_ids` above)
""" """
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None: if input_ids is not None:
num_choices = shape_list(inputs["input_ids"])[1] num_choices = shape_list(input_ids)[1]
seq_length = shape_list(inputs["input_ids"])[2] seq_length = shape_list(input_ids)[2]
else: else:
num_choices = shape_list(inputs["inputs_embeds"])[1] num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs["inputs_embeds"])[2] seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = ( flat_input_ids = (
tf.reshape(tensor=inputs["input_ids"], shape=(-1, seq_length)) if inputs["input_ids"] is not None else None tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
) )
flat_attention_mask = ( flat_attention_mask = (
tf.reshape(tensor=inputs["attention_mask"], shape=(-1, seq_length)) tf.reshape(tensor=attention_mask, shape=(-1, seq_length))
if inputs["attention_mask"] is not None if attention_mask is not None
else None else None
) )
flat_token_type_ids = ( flat_token_type_ids = (
tf.reshape(tensor=inputs["token_type_ids"], shape=(-1, seq_length)) tf.reshape(tensor=token_type_ids, shape=(-1, seq_length))
if inputs["token_type_ids"] is not None if token_type_ids is not None
else None else None
) )
flat_position_ids = ( flat_position_ids = (
tf.reshape(tensor=inputs["position_ids"], shape=(-1, seq_length)) tf.reshape(tensor=position_ids, shape=(-1, seq_length))
if inputs["position_ids"] is not None if position_ids is not None
else None else None
) )
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape( tf.reshape(
tensor=inputs["inputs_embeds"], shape=(-1, seq_length, shape_list(inputs["inputs_embeds"])[3]) tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])
) )
if inputs["inputs_embeds"] is not None if inputs_embeds is not None
else None else None
) )
outputs = self.{{cookiecutter.lowercase_modelname}}( outputs = self.{{cookiecutter.lowercase_modelname}}(
...@@ -1524,19 +1424,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c ...@@ -1524,19 +1424,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
attention_mask=flat_attention_mask, attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids, token_type_ids=flat_token_type_ids,
position_ids=flat_position_ids, position_ids=flat_position_ids,
head_mask=inputs["head_mask"], head_mask=head_mask,
inputs_embeds=flat_inputs_embeds, inputs_embeds=flat_inputs_embeds,
output_attentions=inputs["output_attentions"], output_attentions=output_attentions,
output_hidden_states=inputs["output_hidden_states"], output_hidden_states=output_hidden_states,
return_dict=inputs["return_dict"], return_dict=return_dict,
training=inputs["training"], training=training,
) )
logits = self.sequence_summary(inputs=outputs[0], training=inputs["training"]) logits = self.sequence_summary(inputs=outputs[0], training=training)
logits = self.classifier(inputs=logits) logits = self.classifier(inputs=logits)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits) loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
if not inputs["return_dict"]: if not return_dict:
output = (reshaped_logits,) + outputs[1:] output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
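For reference, the reshaping around the backbone call above follows the usual multiple-choice pattern of folding the choices into the batch dimension; a toy version:

```python
import tensorflow as tf

num_choices, seq_length = 4, 16
input_ids = tf.zeros((2, num_choices, seq_length), dtype=tf.int32)   # (batch, choices, seq)

flat_input_ids = tf.reshape(input_ids, (-1, seq_length))             # (8, 16) seen by the backbone
per_choice_logits = tf.zeros((8, 1))                                  # one score per flattened row
reshaped_logits = tf.reshape(per_choice_logits, (-1, num_choices))    # back to (2, 4) for the loss
```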
...@@ -1585,6 +1485,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut ...@@ -1585,6 +1485,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1611,9 +1512,8 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut ...@@ -1611,9 +1512,8 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
""" """
inputs = input_processing(
func=self.call, outputs = self.{{cookiecutter.lowercase_modelname}}(
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1623,28 +1523,14 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut ...@@ -1623,28 +1523,14 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"]) sequence_output = self.dropout(inputs=sequence_output, training=training)
logits = self.classifier(inputs=sequence_output) logits = self.classifier(inputs=sequence_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1680,6 +1566,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1680,6 +1566,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
) )
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC, processor_class=_TOKENIZER_FOR_DOC,
...@@ -1713,9 +1600,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1713,9 +1600,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
inputs = input_processing( outputs = self.{{cookiecutter.lowercase_modelname}}(
func=self.call,
config=self.config,
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -1725,22 +1610,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1725,22 +1610,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(inputs=sequence_output) logits = self.qa_outputs(inputs=sequence_output)
...@@ -1749,12 +1619,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte ...@@ -1749,12 +1619,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
end_logits = tf.squeeze(input=end_logits, axis=-1) end_logits = tf.squeeze(input=end_logits, axis=-1)
loss = None loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None: if start_positions is not None and end_positions is not None:
labels = {"start_position": inputs["start_positions"]} labels = {"start_position": start_positions}
labels["end_position"] = inputs["end_positions"] labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]: if not return_dict:
output = (start_logits, end_logits) + outputs[2:] output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
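The start/end logits handled above come from a Dense(2) head whose output is split along the last axis (the split itself sits in the collapsed lines of this hunk); a small sketch of that standard pattern:

```python
import tensorflow as tf

sequence_output = tf.random.uniform((2, 16, 8))     # (batch, seq_len, hidden)
qa_head = tf.keras.layers.Dense(2)                  # two scores per token
logits = qa_head(sequence_output)                   # (2, 16, 2)

start_logits, end_logits = tf.split(logits, num_or_size_splits=2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)    # (2, 16)
end_logits = tf.squeeze(end_logits, axis=-1)        # (2, 16)
```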
...@@ -1801,8 +1671,8 @@ from ...modeling_tf_utils import ( ...@@ -1801,8 +1671,8 @@ from ...modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
TFSharedEmbeddings, TFSharedEmbeddings,
TFWrappedEmbeddings, TFWrappedEmbeddings,
input_processing,
keras_serializable, keras_serializable,
unpack_inputs,
); from ...tf_utils import (shape_list, ); from ...tf_utils import (shape_list,
) )
from ...utils import logging from ...utils import logging
...@@ -2381,6 +2251,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): ...@@ -2381,6 +2251,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
def set_embed_tokens(self, embed_tokens): def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
@unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids=None,
...@@ -2435,78 +2306,65 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): ...@@ -2435,78 +2306,65 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
Whether or not to use the model in training mode (some modules like dropout modules have different Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation). behaviors between training and evaluation).
""" """
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None: elif input_ids is not None:
input_shape = shape_list(inputs["input_ids"]) input_shape = shape_list(input_ids)
elif inputs["inputs_embeds"] is not None: elif inputs_embeds is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1] input_shape = shape_list(inputs_embeds)[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["inputs_embeds"] is None: if inputs_embeds is None:
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape) embed_pos = self.embed_positions(input_shape)
hidden_states = inputs["inputs_embeds"] + embed_pos hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=training)
# check attention mask and invert # check attention mask and invert
if inputs["attention_mask"] is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"]) attention_mask = _expand_mask(attention_mask)
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if output_hidden_states else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA, so they # The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager. # have to be disabled in modes other than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly(): if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0], shape_list(head_mask)[0],
len(self.layers), len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
) )
# encoder layers # encoder layers
for idx, encoder_layer in enumerate(self.layers): for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if training and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer( hidden_states, attn = encoder_layer(
hidden_states, hidden_states,
inputs["attention_mask"], attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, head_mask[idx] if head_mask is not None else None,
) )
if inputs["output_attentions"]: if output_attentions:
all_attentions += (attn,) all_attentions += (attn,)
if inputs["output_hidden_states"]: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
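The LayerDrop check inside the loop above (comparing `random.uniform(0, 1)` against `self.layerdrop`) reduces to this standalone sketch:

```python
import random


def skip_this_layer(layerdrop_prob, training):
    # LayerDrop (https://arxiv.org/abs/1909.11556): while training, each layer is
    # skipped independently with probability layerdrop_prob; at inference none are skipped.
    return training and random.uniform(0, 1) < layerdrop_prob


skipped = sum(skip_this_layer(0.2, training=True) for _ in range(10_000))
print(skipped / 10_000)  # roughly 0.2
```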
if not inputs["return_dict"]: if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return TFBaseModelOutput( return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
@@ -2547,6 +2405,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -2632,111 +2491,93 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
         past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
+            shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
         )
         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
-        hidden_states = inputs["inputs_embeds"]
-        inputs["attention_mask"], combined_attention_mask = self.compute_combined_attns_mask(
-            inputs, input_shape, past_key_values_length
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds
+        attention_mask, combined_attention_mask = self.compute_combined_attns_mask(
+            input_ids, attention_mask, input_shape, past_key_values_length
         )
-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
         hidden_states = self.layernorm_embedding(hidden_states + positions)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)
         # decoder layers
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        all_self_attns = () if inputs["output_attentions"] else None
-        all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
-        present_key_values = () if inputs["use_cache"] else None
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+        present_key_values = () if use_cache else None
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in ["head_mask", "cross_attn_head_mask"]:
-            if inputs[attn_mask] is not None and tf.executing_eagerly():
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
+            if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
-                    shape_list(inputs[attn_mask])[0],
+                    shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue
-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
-                if inputs["cross_attn_head_mask"] is not None
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                cross_attn_layer_head_mask=cross_attn_head_mask[idx]
+                if cross_attn_head_mask is not None
                 else None,
                 past_key_value=past_key_value,
             )
-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)
-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)
-                if inputs["encoder_hidden_states"] is not None:
+                if encoder_hidden_states is not None:
                     all_cross_attns += (layer_cross_attn,)
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)
-        if not inputs["return_dict"]:
+        if not return_dict:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
             return TFBaseModelOutputWithPastAndCrossAttentions(
@@ -2748,7 +2589,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
             )
     @tf.function
-    def compute_combined_attns_mask(self, inputs, input_shape, past_key_values_length):
+    def compute_combined_attns_mask(self, input_ids, attention_mask, input_shape, past_key_values_length):
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
         if input_shape[-1] > 1:
@@ -2758,9 +2599,9 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )
-        if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
+        if attention_mask is None and input_ids is not None and input_shape[-1] > 1:
             attention_mask = tf.cast(
-                tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
+                tf.math.not_equal(input_ids, self.config.pad_token_id), input_ids.dtype
            )
            attention_mask = tf.concat(
                [
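For readers skimming the hunk above: the padding mask built from `pad_token_id` is just a 0/1 tensor marking real tokens, and the expression is unchanged by this PR, only its arguments are. A standalone snippet with made-up example values (assuming `pad_token_id=1`) that uses the same expression:

```python
# Standalone illustration of the padding-mask construction used above.
import tensorflow as tf

input_ids = tf.constant([[5, 7, 9, 1, 1]])  # 1 = hypothetical pad token id
pad_token_id = 1

# 1 where the token is real, 0 where it is padding, cast back to the dtype of `input_ids`.
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), input_ids.dtype)
print(attention_mask.numpy())  # [[1 1 1 0 0]]
```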
@@ -2810,6 +2651,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -2830,71 +2672,50 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
-            inputs["use_cache"] = False
-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                head_mask=inputs["head_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            use_cache = False
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
             )
         # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
-            inputs["encoder_outputs"] = TFBaseModelOutput(
-                last_hidden_state=inputs["encoder_outputs"][0],
-                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
-                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
+        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+            encoder_outputs = TFBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
         # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()
         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
-            encoder_hidden_states=inputs["encoder_outputs"][0],
-            encoder_attention_mask=inputs["attention_mask"],
-            head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
-        if not inputs["return_dict"]:
-            return decoder_outputs + inputs["encoder_outputs"]
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -2902,9 +2723,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
         )
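The `encoder_outputs` branches above only normalize the container type so the rest of `call` can consume it either way. A minimal sketch of that behaviour with a hypothetical stand-in class (`SimpleEncoderOutput` and `normalize_encoder_outputs` are made-up names, not the real `TFBaseModelOutput`):

```python
# Hypothetical stand-in for the wrapping logic in the hunk above.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class SimpleEncoderOutput:
    last_hidden_state: object
    hidden_states: Optional[Tuple] = None
    attentions: Optional[Tuple] = None

    def to_tuple(self):
        return tuple(v for v in (self.last_hidden_state, self.hidden_states, self.attentions) if v is not None)


def normalize_encoder_outputs(encoder_outputs, return_dict):
    """Mirror the branches in the diff: wrap tuples when return_dict=True,
    unwrap output objects when return_dict=False, otherwise pass through."""
    if return_dict and not isinstance(encoder_outputs, SimpleEncoderOutput):
        return SimpleEncoderOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )
    if not return_dict and not isinstance(encoder_outputs, tuple):
        return encoder_outputs.to_tuple()
    return encoder_outputs
```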
@@ -2924,6 +2745,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
     def get_decoder(self):
         return self.model.decoder
+    @unpack_inputs
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -2951,9 +2773,8 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -2970,26 +2791,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         return outputs
@@ -3053,6 +2854,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
     def set_output_embeddings(self, value):
         self.set_input_embeddings(value)
+    @unpack_inputs
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -3093,17 +2895,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
         >>> probs = tf.nn.softmax(logits[0])
         >>> # probs[5] is associated with the mask token
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        if labels is not None:
+            use_cache = False
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+        outputs = self.model(
+            input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             head_mask=head_mask,
             decoder_head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -3111,41 +2919,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["labels"] is not None:
-            inputs["use_cache"] = False
-            if inputs["decoder_input_ids"] is None:
-                inputs["decoder_input_ids"] = shift_tokens_right(
-                    inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
-                )
-        outputs = self.model(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            encoder_outputs=inputs["encoder_outputs"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"]
+            training=training
         )
         lm_logits = self.model.shared(outputs[0], mode="linear")
         lm_logits = lm_logits + self.final_logits_bias
-        masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
-        if not inputs["return_dict"]:
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
+        if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
         return TFSeq2SeqLMOutput(
...
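In the conditional-generation head, when `labels` are provided without `decoder_input_ids`, the labels are shifted right to build the teacher-forcing decoder inputs and caching is disabled. A simplified, hypothetical sketch of that shift follows (`shift_tokens_right_sketch` is a made-up name; the real `shift_tokens_right` also handles `-100` label padding, which is omitted here):

```python
# Toy sketch of the teacher-forcing shift used when `labels` is given and
# `decoder_input_ids` is not.
import tensorflow as tf


def shift_tokens_right_sketch(labels, decoder_start_token_id):
    # One start token per sequence, reusing the dtype and batch size of `labels`,
    # then drop the last label position so lengths stay aligned.
    start = tf.ones_like(labels[:, :1]) * decoder_start_token_id
    return tf.concat([start, labels[:, :-1]], axis=-1)


labels = tf.constant([[11, 12, 13, 14]])
print(shift_tokens_right_sketch(labels, decoder_start_token_id=2).numpy())  # [[ 2 11 12 13]]
```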