Unverified Commit da842e4e authored by Julien Plu, committed by GitHub

Add next sentence prediction loss computation (#8462)

* Add next sentence prediction loss computation

* Apply style

* Fix tests

* Add forgotten import

* Add forgotten import

* Use a new parameter

* Remove kwargs and use positional arguments
parent 23290836
@@ -46,6 +46,7 @@ from .modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFNextSentencePredictionLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
@@ -1036,7 +1037,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1045,7 +1046,20 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def call(self, inputs, **kwargs):
def call(
self,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
next_sentence_label=None,
training=False,
):
r"""
Return:
@@ -1064,17 +1078,43 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
>>> assert logits[0][0] < logits[0][1] # the next sentence was random
"""
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.bert.return_dict
outputs = self.bert(inputs, **kwargs)
if isinstance(inputs, (tuple, list)):
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.bert(
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
pooled_output = outputs[1]
seq_relationship_score = self.nsp(pooled_output)
seq_relationship_scores = self.nsp(pooled_output)
next_sentence_loss = (
None
if next_sentence_label is None
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
)
if not return_dict:
return (seq_relationship_score,) + outputs[2:]
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
return TFNextSentencePredictorOutput(
logits=seq_relationship_score,
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
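For reference, a minimal usage sketch of the new next_sentence_label argument (the checkpoint name and label value below are illustrative assumptions, not part of this diff):

import tensorflow as tf
from transformers import BertTokenizer, TFBertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForNextSentencePrediction.from_pretrained("bert-base-uncased")

prompt = "In Italy, pizza served in formal settings is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="tf")

# Label 0 means the second sentence follows the first, 1 means it is a random sentence.
outputs = model(
    encoding["input_ids"],
    token_type_ids=encoding["token_type_ids"],
    next_sentence_label=tf.constant([1]),
    return_dict=True,
)
print(outputs.loss, outputs.logits)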
@@ -44,6 +44,7 @@ from .modeling_tf_outputs import (
from .modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFNextSentencePredictionLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
@@ -1119,7 +1120,7 @@ class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
"""MobileBert Model with a `next sentence prediction (classification)` head on top. """,
MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1128,7 +1129,20 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def call(self, inputs, **kwargs):
def call(
self,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
next_sentence_label=None,
training=False,
):
r"""
Return:
@@ -1146,18 +1160,44 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
"""
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
outputs = self.mobilebert(inputs, **kwargs)
if isinstance(inputs, (tuple, list)):
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.mobilebert(
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
pooled_output = outputs[1]
seq_relationship_score = self.cls(pooled_output)
seq_relationship_scores = self.cls(pooled_output)
next_sentence_loss = (
None
if next_sentence_label is None
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
)
if not return_dict:
return (seq_relationship_score,) + outputs[2:]
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
return TFNextSentencePredictorOutput(
logits=seq_relationship_score,
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
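As with the BERT version above, when a label is passed and return_dict is False the loss is prepended to the tuple output. A minimal sketch (the checkpoint name is an assumption):

import tensorflow as tf
from transformers import MobileBertTokenizer, TFMobileBertForNextSentencePrediction

tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = TFMobileBertForNextSentencePrediction.from_pretrained("google/mobilebert-uncased")

encoding = tokenizer("How old are you?", "The Eiffel Tower is in Paris.", return_tensors="tf")
outputs = model(
    encoding["input_ids"],
    token_type_ids=encoding["token_type_ids"],
    next_sentence_label=tf.constant([1]),  # 1 = the second sentence does not follow the first
    return_dict=False,
)
loss, seq_relationship_scores = outputs[0], outputs[1]  # loss comes first when a label is given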
@@ -307,6 +307,8 @@ class TFNextSentencePredictorOutput(ModelOutput):
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
Next sentence prediction loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
@@ -323,6 +325,7 @@ class TFNextSentencePredictorOutput(ModelOutput):
heads.
"""
loss: tf.Tensor = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@@ -215,6 +215,27 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
"""
class TFNextSentencePredictionLoss:
"""
    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing whether the second sentence follows the first.
.. note::
Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
"""
def compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
        # make sure only labels that are not equal to -100
        # are taken into account in the loss computation
next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
"""
Detect missing and unexpected layers.
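The note above says labels equal to -100 are ignored; a small standalone sketch of that masking logic with made-up values (not part of the diff):

import tensorflow as tf

labels = tf.constant([0, 1, -100])  # the third example carries no label and is ignored
logits = tf.constant([[2.0, -1.0], [0.5, 1.5], [0.0, 0.0]])

active = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), active)
reduced_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
per_example_loss = loss_fn(reduced_labels, reduced_logits)  # shape (2,); the -100 entry is dropped
print(per_example_loss)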
@@ -35,6 +35,7 @@ if is_tf_available():
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -95,6 +96,8 @@ class TFModelTesterMixin:
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in [
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),