"...py_src/git@developer.sourcefind.cn:change/sglang.git" did not exist on "1468769bde6feae691e101df200888125ced5fd0"
Unverified commit da842e4e, authored by Julien Plu and committed by GitHub

Add next sentence prediction loss computation (#8462)

* Add next sentence prediction loss computation

* Apply style

* Fix tests

* Add forgotten import

* Add forgotten import

* Use a new parameter

* Remove kwargs and use positional arguments
parent 23290836
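
For context, a minimal usage sketch (not part of this diff) of the API the commit enables: passing next_sentence_label to TFBertForNextSentencePrediction now produces the NSP loss alongside the logits. The "bert-base-uncased" checkpoint name is only illustrative.

# Usage sketch, not part of the diff; any BERT checkpoint with an NSP head works the same way.
import tensorflow as tf
from transformers import BertTokenizer, TFBertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForNextSentencePrediction.from_pretrained("bert-base-uncased")

prompt = "In Italy, pizza served in formal settings is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="tf")

# Label 0 = the second sentence really follows the first, 1 = it is random.
outputs = model(
    encoding["input_ids"],
    token_type_ids=encoding["token_type_ids"],
    next_sentence_label=tf.constant([1]),
    return_dict=True,
)
print(outputs.loss)    # NSP loss, returned because next_sentence_label was given
print(outputs.logits)  # shape (1, 2): scores for "is next" vs. "random"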
@@ -46,6 +46,7 @@ from .modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFMaskedLanguageModelingLoss,
     TFMultipleChoiceLoss,
+    TFNextSentencePredictionLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
@@ -1036,7 +1037,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
     """Bert Model with a `next sentence prediction (classification)` head on top. """,
     BERT_START_DOCSTRING,
 )
-class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
+class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1045,7 +1046,20 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        inputs=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        next_sentence_label=None,
+        training=False,
+    ):
         r"""
         Return:
@@ -1064,17 +1078,43 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
         >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
         >>> assert logits[0][0] < logits[0][1]  # the next sentence was random
         """
-        return_dict = kwargs.get("return_dict")
         return_dict = return_dict if return_dict is not None else self.bert.return_dict
-        outputs = self.bert(inputs, **kwargs)
+
+        if isinstance(inputs, (tuple, list)):
+            next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
+            if len(inputs) > 9:
+                inputs = inputs[:9]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
+
+        outputs = self.bert(
+            inputs,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
         pooled_output = outputs[1]
-        seq_relationship_score = self.nsp(pooled_output)
+        seq_relationship_scores = self.nsp(pooled_output)
+
+        next_sentence_loss = (
+            None
+            if next_sentence_label is None
+            else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
+        )

         if not return_dict:
-            return (seq_relationship_score,) + outputs[2:]
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

         return TFNextSentencePredictorOutput(
-            logits=seq_relationship_score,
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
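
As a side note (not part of the diff), the tuple/dict handling above means the label can also be packed into the model inputs instead of being passed as a keyword argument. A small sketch, with the checkpoint name again only illustrative:

# Sketch of the dict-input path: "next_sentence_label" is popped from the dict
# before the remaining tensors are forwarded to self.bert.
import tensorflow as tf
from transformers import BertTokenizer, TFBertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForNextSentencePrediction.from_pretrained("bert-base-uncased")

inputs = dict(tokenizer("First sentence.", "A random second sentence.", return_tensors="tf"))
inputs["next_sentence_label"] = tf.constant([1])

outputs = model(inputs, return_dict=True)  # equivalent to passing the label as a kwarg
print(outputs.loss, outputs.logits.shape)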
@@ -44,6 +44,7 @@ from .modeling_tf_outputs import (
 from .modeling_tf_utils import (
     TFMaskedLanguageModelingLoss,
     TFMultipleChoiceLoss,
+    TFNextSentencePredictionLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
@@ -1119,7 +1120,7 @@ class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
     """MobileBert Model with a `next sentence prediction (classification)` head on top. """,
     MOBILEBERT_START_DOCSTRING,
 )
-class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
+class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1128,7 +1129,20 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        inputs=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        next_sentence_label=None,
+        training=False,
+    ):
         r"""
         Return:
@@ -1146,18 +1160,44 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
         >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
         """
-        return_dict = kwargs.get("return_dict")
         return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
-        outputs = self.mobilebert(inputs, **kwargs)
+
+        if isinstance(inputs, (tuple, list)):
+            next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
+            if len(inputs) > 9:
+                inputs = inputs[:9]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
+
+        outputs = self.mobilebert(
+            inputs,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
         pooled_output = outputs[1]
-        seq_relationship_score = self.cls(pooled_output)
+        seq_relationship_scores = self.cls(pooled_output)
+
+        next_sentence_loss = (
+            None
+            if next_sentence_label is None
+            else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
+        )

         if not return_dict:
-            return (seq_relationship_score,) + outputs[2:]
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

         return TFNextSentencePredictorOutput(
-            logits=seq_relationship_score,
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
@@ -307,6 +307,8 @@ class TFNextSentencePredictorOutput(ModelOutput):
     Base class for outputs of models predicting if two sentences are consecutive or not.

     Args:
+        loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
+            Next sentence prediction loss.
         logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
             Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
             before SoftMax).
@@ -323,6 +325,7 @@ class TFNextSentencePredictorOutput(ModelOutput):
             heads.
     """

+    loss: tf.Tensor = None
     logits: tf.Tensor = None
     hidden_states: Optional[Tuple[tf.Tensor]] = None
     attentions: Optional[Tuple[tf.Tensor]] = None
@@ -215,6 +215,27 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
     """


+class TFNextSentencePredictionLoss:
+    """
+    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.
+
+    .. note::
+        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+    """
+
+    def compute_loss(self, labels, logits):
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
+        )
+
+        # make sure only labels that are not equal to -100
+        # are taken into account as loss
+        next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
+        next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
+
+        return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+
+
 def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
     """
     Detect missing and unexpected layers.
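
To make the masking behaviour concrete, here is a standalone sketch (not part of the diff) that reproduces the same TensorFlow ops as compute_loss above on toy tensors; a label of -100 drops the corresponding example from the loss.

# Toy illustration of the -100 masking used by TFNextSentencePredictionLoss.compute_loss.
import tensorflow as tf

labels = tf.constant([0, -100, 1])           # second example is ignored (-100)
logits = tf.constant([[2.0, -1.0],           # confident "is next"  -> label 0
                      [0.3, 0.3],            # ignored
                      [-1.5, 2.5]])          # confident "random"   -> label 1

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
active = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), active)
reduced_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)

per_example_loss = loss_fn(reduced_labels, reduced_logits)
print(per_example_loss)  # two values, one per non-ignored example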
@@ -35,6 +35,7 @@ if is_tf_available():
         TF_MODEL_FOR_CAUSAL_LM_MAPPING,
         TF_MODEL_FOR_MASKED_LM_MAPPING,
         TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+        TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
         TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
         TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -95,6 +96,8 @@ class TFModelTesterMixin:
             inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
         elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
             inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
+        elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
+            inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
         elif model_class in [
             *TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
             *TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),