Unverified Commit c2f8eaf6 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: unpack inputs on Convbert, GPTJ, LED, and templates (#16491)

* Add unpack_inputs to remaining models

* remove stray use of inputs in the templates; fix tf.debugging of attn masks
parent ae189ef9
......@@ -943,12 +943,12 @@ class TFBartDecoder(tf.keras.layers.Layer):
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask in [head_mask, cross_attn_head_mask]:
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
......
......@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
TFSequenceSummary,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import (
......@@ -568,6 +568,7 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
return head_mask
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -582,60 +583,36 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
hidden_states = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
extended_attention_mask = self.get_extended_attention_mask(
inputs["attention_mask"], input_shape, hidden_states.dtype
)
inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
head_mask = self.get_head_mask(head_mask)
if hasattr(self, "embeddings_project"):
hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
hidden_states = self.embeddings_project(hidden_states, training=training)
hidden_states = self.encoder(
hidden_states,
extended_attention_mask,
inputs["head_mask"],
inputs["output_attentions"],
inputs["output_hidden_states"],
inputs["return_dict"],
training=inputs["training"],
head_mask,
output_attentions,
output_hidden_states,
return_dict,
training=training,
)
return hidden_states
......@@ -754,6 +731,7 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
self.convbert = TFConvBertMainLayer(config, name="convbert")
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -775,9 +753,7 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.convbert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -788,19 +764,6 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.convbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -886,6 +849,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
def get_prefix_bias_name(self):
return self.name + "/" + self.generator_lm_head.name
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -914,9 +878,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
inputs = input_processing(
func=self.call,
config=self.config,
generator_hidden_states = self.convbert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -926,28 +888,14 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
generator_hidden_states = self.convbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
generator_sequence_output = generator_hidden_states[0]
prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
prediction_scores = self.generator_lm_head(prediction_scores, training=training)
loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
if not inputs["return_dict"]:
if not return_dict:
output = (prediction_scores,) + generator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1010,6 +958,7 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
self.convbert = TFConvBertMainLayer(config, name="convbert")
self.classifier = TFConvBertClassificationHead(config, name="classifier")
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1038,10 +987,8 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
outputs = self.convbert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1050,26 +997,12 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.convbert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
logits = self.classifier(outputs[0], training=inputs["training"])
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
logits = self.classifier(outputs[0], training=training)
loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1117,6 +1050,7 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
"""
return {"input_ids": tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@unpack_inputs
@add_start_docstrings_to_model_forward(
CONVBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
......@@ -1146,43 +1080,20 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs_embeds = (
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
else None
)
outputs = self.convbert(
......@@ -1190,19 +1101,19 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
inputs["head_mask"],
head_mask,
flat_inputs_embeds,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
output_attentions,
output_hidden_states,
return_dict=return_dict,
training=training,
)
logits = self.sequence_summary(outputs[0], training=inputs["training"])
logits = self.sequence_summary(outputs[0], training=training)
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
if not inputs["return_dict"]:
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1256,6 +1167,7 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1282,10 +1194,8 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
outputs = self.convbert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1294,28 +1204,14 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.convbert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1350,6 +1246,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1383,10 +1280,8 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
outputs = self.convbert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
......@@ -1395,22 +1290,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
outputs = self.convbert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
......@@ -1419,12 +1299,12 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......
......@@ -40,7 +40,6 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFSharedEmbeddings,
get_initializer,
input_processing,
keras_serializable,
unpack_inputs,
)
......@@ -376,6 +375,7 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -392,53 +392,34 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["past_key_values"] is None:
if past_key_values is None:
past_length = 0
inputs["past_key_values"] = [None] * len(self.h)
past_key_values = [None] * len(self.h)
else:
past_length = shape_list(inputs["past_key_values"][0][0])[-2]
past_length = shape_list(past_key_values[0][0])[-2]
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
if position_ids is None:
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
if inputs["attention_mask"] is not None:
if attention_mask is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask_shape = shape_list(inputs["attention_mask"])
inputs["attention_mask"] = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
)
attention_mask_shape = shape_list(attention_mask)
attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]))
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
......@@ -446,78 +427,74 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
one_cst = tf.constant(1.0)
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
inputs["attention_mask"] = tf.multiply(
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
)
attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None:
if head_mask is not None:
raise NotImplementedError
else:
inputs["head_mask"] = [None] * self.num_hidden_layers
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids, mode="embedding")
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.wte(token_type_ids, mode="embedding")
else:
token_type_embeds = tf.constant(0.0)
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
hidden_states = inputs["inputs_embeds"] + token_type_embeds
hidden_states = self.drop(hidden_states, training=inputs["training"])
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
hidden_states = inputs_embeds + token_type_embeds
hidden_states = self.drop(hidden_states, training=training)
output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if inputs["use_cache"] else None
all_attentions = () if inputs["output_attentions"] else None
all_hidden_states = () if inputs["output_hidden_states"] else None
for i, (block, layer_past) in enumerate(zip(self.h, inputs["past_key_values"])):
if inputs["output_hidden_states"]:
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block(
hidden_states,
layer_past,
inputs["attention_mask"],
inputs["head_mask"][i],
inputs["use_cache"],
inputs["output_attentions"],
training=inputs["training"],
attention_mask,
head_mask[i],
use_cache,
output_attentions,
training=training,
)
hidden_states = outputs[0]
if inputs["use_cache"]:
if use_cache:
presents = presents + (outputs[1],)
if inputs["output_attentions"]:
all_attentions = all_attentions + (outputs[2 if inputs["use_cache"] else 1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.ln_f(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape)
# Add last hidden state
if inputs["output_hidden_states"]:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if inputs["output_attentions"]:
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
if not inputs["return_dict"]:
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutputWithPast(
......
......@@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings,
TFWrappedEmbeddings,
get_initializer,
input_processing,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import (
......@@ -1654,6 +1654,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -1703,95 +1704,74 @@ class TFLEDEncoder(tf.keras.layers.Layer):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
global_attention_mask=global_attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
inputs_embeds = self.embed_tokens(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
# merge `global_attention_mask` and `attention_mask`
if inputs["global_attention_mask"] is not None:
inputs["attention_mask"] = inputs["attention_mask"] * tf.cast(
(inputs["global_attention_mask"] + 1), dtype=inputs["attention_mask"].dtype
)
if global_attention_mask is not None:
attention_mask = attention_mask * tf.cast((global_attention_mask + 1), dtype=attention_mask.dtype)
(
padding_len,
inputs["input_ids"],
inputs["attention_mask"],
inputs["inputs_embeds"],
) = self._pad_to_window_size(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["inputs_embeds"],
(padding_len, input_ids, attention_mask, inputs_embeds,) = self._pad_to_window_size(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
pad_token_id=self.padding_idx,
)
input_shape = shape_list(inputs["attention_mask"])
input_shape = shape_list(attention_mask)
# is index masked or global attention
is_index_masked = tf.math.less(tf.cast(inputs["attention_mask"], tf.int8), 1)
is_index_global_attn = tf.math.greater(tf.cast(inputs["attention_mask"], tf.int8), 1)
is_index_masked = tf.math.less(tf.cast(attention_mask, tf.int8), 1)
is_index_global_attn = tf.math.greater(tf.cast(attention_mask, tf.int8), 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs["inputs_embeds"] + embed_pos
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
hidden_states = self.dropout(hidden_states, training=training)
# check attention mask and invert
if inputs["attention_mask"] is not None:
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])[:, 0, 0, :]
inputs["attention_mask"] = inputs["attention_mask"][:, :, None, None]
attention_mask = _expand_mask(attention_mask)[:, 0, 0, :]
attention_mask = attention_mask[:, :, None, None]
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
encoder_states = () if output_hidden_states else None
all_attentions = all_global_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None and tf.executing_eagerly():
if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
shape_list(head_mask)[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
if output_hidden_states:
hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
encoder_states = encoder_states + (hidden_states_to_add,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
if training and (dropout_probability < self.layerdrop): # skip the layer
continue
layer_outputs = encoder_layer(
hidden_states=hidden_states,
attention_mask=inputs["attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
attention_mask=attention_mask,
layer_head_mask=head_mask[idx] if head_mask is not None else None,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
......@@ -1799,7 +1779,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
hidden_states = layer_outputs[0]
if inputs["output_attentions"]:
if output_attentions:
# bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
......@@ -1811,17 +1791,17 @@ class TFLEDEncoder(tf.keras.layers.Layer):
hidden_states = self.compute_hidden_states(hidden_states, padding_len)
# undo padding
if inputs["output_attentions"]:
if output_attentions:
all_attentions = (
tuple([state[:, :, :-padding_len, :] for state in all_attentions])
if padding_len > 0
else all_attentions
)
if inputs["output_hidden_states"]:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not inputs["return_dict"]:
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return TFLEDEncoderBaseModelOutput(
last_hidden_state=hidden_states,
......@@ -1915,6 +1895,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -1985,45 +1966,25 @@ class TFLEDDecoder(tf.keras.layers.Layer):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
past_key_values_length = (
shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
)
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs["inputs_embeds"]
hidden_states = inputs_embeds
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if input_shape[-1] > 1:
......@@ -2033,17 +1994,15 @@ class TFLEDDecoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1]
)
if attention_mask is not None and input_shape[-1] > 1:
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
hidden_states = self.layernorm_embedding(hidden_states + positions)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
hidden_states = self.dropout(hidden_states, training=training)
# decoder layers
all_hidden_states = ()
......@@ -2052,54 +2011,52 @@ class TFLEDDecoder(tf.keras.layers.Layer):
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None and tf.executing_eagerly():
if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
shape_list(head_mask)[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop):
if training and (dropout_probability < self.layerdrop):
continue
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
past_key_value = past_key_values[idx] if past_key_values is not None else None
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=head_mask[idx] if head_mask is not None else None,
encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
past_key_value=past_key_value,
)
if inputs["use_cache"]:
if use_cache:
present_key_values += (present_key_value,)
if inputs["output_attentions"]:
if output_attentions:
all_self_attns += (layer_self_attn,)
all_cross_attentions += (layer_cross_attn,)
if inputs["output_hidden_states"]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
else:
all_hidden_states = None
all_self_attns = all_self_attns if inputs["output_attentions"] else None
all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
all_self_attns = all_self_attns if output_attentions else None
all_cross_attentions = all_cross_attentions if output_attentions else None
present_key_values = present_key_values if inputs["use_cache"] else None
present_key_values = present_key_values if use_cache else None
if not inputs["return_dict"]:
if not return_dict:
return tuple(
v
for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
......@@ -2149,6 +2106,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -2169,72 +2127,51 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
if decoder_input_ids is None and decoder_inputs_embeds is None:
use_cache = False
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask,
past_key_values=past_key_values,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
inputs["use_cache"] = False
if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
# If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True
elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFLEDEncoderBaseModelOutput):
inputs["encoder_outputs"] = TFLEDEncoderBaseModelOutput(
last_hidden_state=inputs["encoder_outputs"][0],
hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
elif return_dict and not isinstance(encoder_outputs, TFLEDEncoderBaseModelOutput):
encoder_outputs = TFLEDEncoderBaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
elif not return_dict and not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple()
decoder_outputs = self.decoder(
inputs["decoder_input_ids"],
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
if not inputs["return_dict"]:
return decoder_outputs + inputs["encoder_outputs"]
decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return TFLEDSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
......@@ -2242,10 +2179,10 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
encoder_global_attentions=inputs["encoder_outputs"].global_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
encoder_global_attentions=encoder_outputs.global_attentions,
)
......@@ -2265,6 +2202,7 @@ class TFLEDModel(TFLEDPreTrainedModel):
def get_decoder(self):
return self.led.decoder
@unpack_inputs
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -2292,17 +2230,16 @@ class TFLEDModel(TFLEDPreTrainedModel):
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.led(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
......@@ -2311,25 +2248,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.led(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
global_attention_mask=inputs["global_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -2393,6 +2311,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
@unpack_inputs
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
......@@ -2435,17 +2354,22 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
>>> # probs[5] is associated with the mask token
```"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
if labels is not None:
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.led(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
global_attention_mask=global_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
......@@ -2453,41 +2377,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["labels"] is not None:
inputs["use_cache"] = False
if inputs["decoder_input_ids"] is None:
inputs["decoder_input_ids"] = shift_tokens_right(
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.led(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
global_attention_mask=inputs["global_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
lm_logits = self.led.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
if not inputs["return_dict"]:
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return TFLEDSeq2SeqLMOutput(
......
......@@ -965,12 +965,12 @@ class TFMBartDecoder(tf.keras.layers.Layer):
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask in [head_mask, cross_attn_head_mask]:
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
......
......@@ -1060,12 +1060,12 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
for attn_mask in [head_mask, cross_attn_head_mask]:
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
......
......@@ -50,8 +50,8 @@ from ...modeling_tf_utils import (
TFSequenceSummary,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
......@@ -636,6 +636,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
@unpack_inputs
def call(
self,
input_ids: Optional[TFModelInputType] = None,
......@@ -654,59 +655,40 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
training: bool = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if not self.config.is_decoder:
inputs["use_cache"] = False
use_cache = False
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if inputs["past_key_values"] is None:
if past_key_values is None:
past_key_values_length = 0
inputs["past_key_values"] = [None] * len(self.encoder.layer)
past_key_values = [None] * len(self.encoder.layer)
else:
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
past_key_values_length = shape_list(past_key_values[0][0])[-2]
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
if attention_mask is None:
attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
if token_type_ids is None:
token_type_ids = tf.fill(dims=input_shape, value=0)
embedding_output = self.embeddings(
input_ids=inputs["input_ids"],
position_ids=inputs["position_ids"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
training=inputs["training"],
training=training,
)
# We create a 3D attention mask from a 2D tensor mask.
......@@ -714,7 +696,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask_shape = shape_list(inputs["attention_mask"])
attention_mask_shape = shape_list(attention_mask)
mask_seq_length = seq_length + past_key_values_length
# Copied from `modeling_tf_t5.py`
......@@ -727,18 +709,18 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
extended_attention_mask = causal_mask * attention_mask[:, None, :]
attention_mask_shape = shape_list(extended_attention_mask)
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
if inputs["past_key_values"][0] is not None:
if past_key_values[0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
......@@ -752,18 +734,18 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
if self.is_decoder and encoder_attention_mask is not None:
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
inputs["encoder_attention_mask"] = tf.cast(
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
encoder_attention_mask = tf.cast(
encoder_attention_mask, dtype=extended_attention_mask.dtype
)
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
......@@ -779,28 +761,28 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None:
if head_mask is not None:
raise NotImplementedError
else:
inputs["head_mask"] = [None] * self.config.num_hidden_layers
head_mask = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder(
hidden_states=embedding_output,
attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"],
encoder_hidden_states=inputs["encoder_hidden_states"],
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = encoder_outputs[0]
if not inputs["return_dict"]:
if not return_dict:
return (
sequence_output,
) + encoder_outputs[1:]
......@@ -943,6 +925,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -988,9 +971,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`). Set to `False` during training, `True` during generation
"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -1005,23 +986,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -1064,6 +1028,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
def get_lm_head(self) -> tf.keras.layers.Layer:
return self.mlm.predictions
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1091,9 +1056,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -1103,29 +1066,15 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
loss = (
None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
)
if not inputs["return_dict"]:
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
......@@ -1173,6 +1122,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
"use_cache": model_kwargs["use_cache"],
}
@unpack_inputs
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
......@@ -1220,9 +1170,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`.
"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -1236,37 +1184,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.mlm(sequence_output=sequence_output, training=inputs["training"])
logits = self.mlm(sequence_output=sequence_output, training=training)
loss = None
if inputs["labels"] is not None:
if labels is not None:
# shift labels to the left and cut last logit token
shifted_logits = logits[:, :-1]
labels = inputs["labels"][:, 1:]
labels = labels[:, 1:]
loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
......@@ -1338,6 +1268,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
self.classifier = TF{{cookiecutter.camelcase_modelname}}ClassificationHead(config, name="classifier")
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1365,9 +1296,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -1377,26 +1306,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
logits = self.classifier(hidden_states=outputs[0], training=inputs["training"])
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
logits = self.classifier(hidden_states=outputs[0], training=training)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1443,6 +1358,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1470,53 +1386,37 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above)
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = (
tf.reshape(tensor=inputs["input_ids"], shape=(-1, seq_length)) if inputs["input_ids"] is not None else None
tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
)
flat_attention_mask = (
tf.reshape(tensor=inputs["attention_mask"], shape=(-1, seq_length))
if inputs["attention_mask"] is not None
tf.reshape(tensor=attention_mask, shape=(-1, seq_length))
if attention_mask is not None
else None
)
flat_token_type_ids = (
tf.reshape(tensor=inputs["token_type_ids"], shape=(-1, seq_length))
if inputs["token_type_ids"] is not None
tf.reshape(tensor=token_type_ids, shape=(-1, seq_length))
if token_type_ids is not None
else None
)
flat_position_ids = (
tf.reshape(tensor=inputs["position_ids"], shape=(-1, seq_length))
if inputs["position_ids"] is not None
tf.reshape(tensor=position_ids, shape=(-1, seq_length))
if position_ids is not None
else None
)
flat_inputs_embeds = (
tf.reshape(
tensor=inputs["inputs_embeds"], shape=(-1, seq_length, shape_list(inputs["inputs_embeds"])[3])
tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])
)
if inputs["inputs_embeds"] is not None
if inputs_embeds is not None
else None
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
......@@ -1524,19 +1424,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids,
position_ids=flat_position_ids,
head_mask=inputs["head_mask"],
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
logits = self.sequence_summary(inputs=outputs[0], training=inputs["training"])
logits = self.sequence_summary(inputs=outputs[0], training=training)
logits = self.classifier(inputs=logits)
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
if not inputs["return_dict"]:
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1585,6 +1485,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1611,9 +1512,8 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -1623,28 +1523,14 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
sequence_output = self.dropout(inputs=sequence_output, training=training)
logits = self.classifier(inputs=sequence_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
......@@ -1680,6 +1566,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -1713,9 +1600,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......@@ -1725,22 +1610,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.qa_outputs(inputs=sequence_output)
......@@ -1749,12 +1619,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
end_logits = tf.squeeze(input=end_logits, axis=-1)
loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
if not inputs["return_dict"]:
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output
......@@ -1801,8 +1671,8 @@ from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
input_processing,
keras_serializable,
unpack_inputs,
); from ...tf_utils import (shape_list,
)
from ...utils import logging
......@@ -2381,6 +2251,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -2435,78 +2306,65 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs["inputs_embeds"] + embed_pos
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
hidden_states = self.dropout(hidden_states, training=training)
# check attention mask and invert
if inputs["attention_mask"] is not None:
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])
attention_mask = _expand_mask(attention_mask)
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if inputs["head_mask"] is not None and tf.executing_eagerly():
if head_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
shape_list(head_mask)[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
if training and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(
hidden_states,
inputs["attention_mask"],
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
attention_mask,
head_mask[idx] if head_mask is not None else None,
)
if inputs["output_attentions"]:
if output_attentions:
all_attentions += (attn,)
if inputs["output_hidden_states"]:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not inputs["return_dict"]:
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
......@@ -2547,6 +2405,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -2632,111 +2491,93 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
past_key_values_length = (
shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
)
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs["inputs_embeds"]
hidden_states = inputs_embeds
inputs["attention_mask"], combined_attention_mask = self.compute_combined_attns_mask(
inputs, input_shape, past_key_values_length
attention_mask, combined_attention_mask = self.compute_combined_attns_mask(
input_ids, attention_mask, input_shape, past_key_values_length
)
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
hidden_states = self.layernorm_embedding(hidden_states + positions)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
hidden_states = self.dropout(hidden_states, training=training)
# decoder layers
all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () if inputs["output_attentions"] else None
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
present_key_values = () if inputs["use_cache"] else None
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
if inputs[attn_mask] is not None and tf.executing_eagerly():
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs[attn_mask])[0],
shape_list(attn_mask)[0],
len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if inputs["training"] and (dropout_probability < self.layerdrop):
if training and (dropout_probability < self.layerdrop):
continue
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
past_key_value = past_key_values[idx] if past_key_values is not None else None
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
if inputs["cross_attn_head_mask"] is not None
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=head_mask[idx] if head_mask is not None else None,
cross_attn_layer_head_mask=cross_attn_head_mask[idx]
if cross_attn_head_mask is not None
else None,
past_key_value=past_key_value,
)
if inputs["use_cache"]:
if use_cache:
present_key_values += (present_key_value,)
if inputs["output_attentions"]:
if output_attentions:
all_self_attns += (layer_self_attn,)
if inputs["encoder_hidden_states"] is not None:
if encoder_hidden_states is not None:
all_cross_attns += (layer_cross_attn,)
if inputs["output_hidden_states"]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not inputs["return_dict"]:
if not return_dict:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
return TFBaseModelOutputWithPastAndCrossAttentions(
......@@ -2748,7 +2589,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
)
@tf.function
def compute_combined_attns_mask(self, inputs, input_shape, past_key_values_length):
def compute_combined_attns_mask(self, input_ids, attention_mask, input_shape, past_key_values_length):
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
......@@ -2758,9 +2599,9 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
if attention_mask is None and input_ids is not None and input_shape[-1] > 1:
attention_mask = tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
tf.math.not_equal(input_ids, self.config.pad_token_id), input_ids.dtype
)
attention_mask = tf.concat(
[
......@@ -2810,6 +2651,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@unpack_inputs
def call(
self,
input_ids=None,
......@@ -2830,71 +2672,50 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
if decoder_input_ids is None and decoder_inputs_embeds is None:
use_cache = False
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
inputs["use_cache"] = False
if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
inputs["encoder_outputs"] = TFBaseModelOutput(
last_hidden_state=inputs["encoder_outputs"][0],
hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
encoder_outputs = TFBaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
elif not return_dict and not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple()
decoder_outputs = self.decoder(
inputs["decoder_input_ids"],
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
if not inputs["return_dict"]:
return decoder_outputs + inputs["encoder_outputs"]
decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
......@@ -2902,9 +2723,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
......@@ -2924,6 +2745,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
def get_decoder(self):
return self.model.decoder
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
......@@ -2951,9 +2773,8 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
......@@ -2970,26 +2791,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
......@@ -3053,6 +2854,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
......@@ -3093,17 +2895,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
>>> probs = tf.nn.softmax(logits[0])
>>> # probs[5] is associated with the mask token
```"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
if labels is not None:
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
......@@ -3111,41 +2919,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if inputs["labels"] is not None:
inputs["use_cache"] = False
if inputs["decoder_input_ids"] is None:
inputs["decoder_input_ids"] = shift_tokens_right(
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
cross_attn_head_mask=inputs["cross_attn_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"]
training=training
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
if not inputs["return_dict"]:
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return TFSeq2SeqLMOutput(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment