Unverified commit 74f16b82, authored by Kiyoung Kim and committed by GitHub

TFBart labels consider both pad token and -100 (#9847)



* TFBart labels consider both pad token and -100

* make style

* fix for all other models

Co-authored-by: kykim <kykim>
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
parent 22121e81
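
The same change is applied to six TF seq2seq models (BART, Blenderbot, BlenderbotSmall, Marian, MBart, Pegasus): before the loss is computed, any pad-token id in labels is rewritten to -100, the ignore index that the shared TFCausalLanguageModelingLoss mixin masks out, so labels may use either convention. A minimal standalone sketch of the rewrite each diff hunk adds (normalize_labels is an illustrative name, and tf.shape stands in for the library's shape_list helper):

import tensorflow as tf

def normalize_labels(labels, pad_token_id):
    # Rewrite pad-token positions to -100, the ignore index the loss
    # masking skips; mirrors the tf.where(...) block this commit adds
    # before loss computation in each model's call().
    return tf.where(
        labels == pad_token_id,
        tf.fill(tf.shape(labels), -100),
        labels,
    )

labels = tf.constant([[250, 417, 1, 1]])  # 1 = hypothetical pad token id
print(normalize_labels(labels, pad_token_id=1).numpy())  # [[ 250  417 -100 -100]]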
src/transformers/models/bart/modeling_tf_bart.py
@@ -38,6 +38,7 @@ from ...modeling_tf_outputs import (
 # Public API
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
+    TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -1234,7 +1235,7 @@ class TFBartModel(TFBartPretrainedModel):
     "The BART Model with a language modeling head. Can be used for summarization.",
     BART_START_DOCSTRING,
 )
-class TFBartForConditionalGeneration(TFBartPretrainedModel):
+class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
     _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
@@ -1322,6 +1323,11 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
         )
         if inputs["labels"] is not None:
+            inputs["labels"] = tf.where(
+                inputs["labels"] == self.config.pad_token_id,
+                tf.fill(shape_list(inputs["labels"]), -100),
+                inputs["labels"],
+            )
             inputs["use_cache"] = False
             if inputs["decoder_input_ids"] is None:
                 inputs["decoder_input_ids"] = shift_tokens_right(
@@ -1448,15 +1454,3 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
             return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
         else:
             return logits
-
-    def compute_loss(self, labels, logits):
-        """CrossEntropyLoss that ignores pad tokens"""
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.NONE,
-        )
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        return loss_fn(labels, reduced_logits)
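
The compute_loss removed above masked pad tokens only; labels that already used the -100 convention (as the PyTorch models expect) fell through to the cross entropy as invalid class ids. The TFCausalLanguageModelingLoss mixin masks on -100 instead, which is why the labels are rewritten first. Roughly, its masking behaves like this sketch (a simplification, not the exact library code):

import tensorflow as tf

def masked_lm_loss(labels, logits):
    # Token-level cross entropy over positions whose label is not -100,
    # approximating what TFCausalLanguageModelingLoss.compute_loss does.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    flat_labels = tf.reshape(labels, (-1,))
    active = tf.not_equal(flat_labels, -100)  # -100 marks ignored positions
    flat_logits = tf.reshape(logits, (-1, tf.shape(logits)[-1]))
    return loss_fn(
        tf.boolean_mask(flat_labels, active),
        tf.boolean_mask(flat_logits, active),
    )

The net effect for pad tokens is unchanged, but the TF models now accept the same -100 convention as the PyTorch models' CrossEntropyLoss(ignore_index=-100).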
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -40,6 +40,7 @@ from ...modeling_tf_outputs import (
 # Public API
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
+    TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -1251,7 +1252,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
     "The BLENDERBOT Model with a language modeling head. Can be used for summarization.",
     BLENDERBOT_START_DOCSTRING,
 )
-class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
+class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausalLanguageModelingLoss):
     _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
@@ -1352,6 +1353,12 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
         )
         if inputs["labels"] is not None:
+            inputs["labels"] = tf.where(
+                inputs["labels"] == self.config.pad_token_id,
+                tf.fill(shape_list(inputs["labels"]), -100),
+                inputs["labels"],
+            )
+            inputs["use_cache"] = False
             if inputs["decoder_input_ids"] is None:
                 inputs["decoder_input_ids"] = shift_tokens_right(
                     inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
@@ -1477,16 +1484,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
             return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
         else:
             return logits
-
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
-    def compute_loss(self, labels, logits):
-        """CrossEntropyLoss that ignores pad tokens"""
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.NONE,
-        )
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        return loss_fn(labels, reduced_logits)
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
@@ -38,6 +38,7 @@ from ...modeling_tf_outputs import (
 # Public API
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
+    TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -1239,7 +1240,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
     "The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization.",
     BLENDERBOT_SMALL_START_DOCSTRING,
 )
-class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel):
+class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel, TFCausalLanguageModelingLoss):
     _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
@@ -1327,6 +1328,12 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
         )
         if inputs["labels"] is not None:
+            inputs["labels"] = tf.where(
+                inputs["labels"] == self.config.pad_token_id,
+                tf.fill(shape_list(inputs["labels"]), -100),
+                inputs["labels"],
+            )
+            inputs["use_cache"] = False
             if inputs["decoder_input_ids"] is None:
                 inputs["decoder_input_ids"] = shift_tokens_right(
                     inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
@@ -1452,16 +1459,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
             return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
         else:
             return logits
-
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
-    def compute_loss(self, labels, logits):
-        """CrossEntropyLoss that ignores pad tokens"""
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.NONE,
-        )
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        return loss_fn(labels, reduced_logits)
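
One subtlety: decoder_input_ids is built from labels after the rewrite, so the shift_tokens_right helpers receive labels that may contain -100 and are expected to map it back to pad_token_id (this diff leaves those helpers untouched). A sketch of that behavior under the common prepend-start-token scheme; MBart's helper differs (note it is called with only pad_token_id below):

import tensorflow as tf

def shift_tokens_right_sketch(labels, pad_token_id, decoder_start_token_id):
    # Prepend the decoder start token, drop the last position, and map
    # any -100 back to pad_token_id so decoder inputs are valid token ids.
    # A sketch only; the real helpers live in each modeling_tf_*.py file.
    start = tf.fill((tf.shape(labels)[0], 1), decoder_start_token_id)
    shifted = tf.concat([start, labels[:, :-1]], axis=-1)
    return tf.where(
        shifted == -100,
        tf.fill(tf.shape(shifted), pad_token_id),
        shifted,
    )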
src/transformers/models/marian/modeling_tf_marian.py
@@ -39,6 +39,7 @@ from ...modeling_tf_outputs import (
 # Public API
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
+    TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -1256,7 +1257,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
     "The MARIAN Model with a language modeling head. Can be used for summarization.",
     MARIAN_START_DOCSTRING,
 )
-class TFMarianMTModel(TFMarianPreTrainedModel):
+class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
@@ -1344,6 +1345,11 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
         )
         if inputs["labels"] is not None:
+            inputs["labels"] = tf.where(
+                inputs["labels"] == self.config.pad_token_id,
+                tf.fill(shape_list(inputs["labels"]), -100),
+                inputs["labels"],
+            )
             inputs["use_cache"] = False
             if inputs["decoder_input_ids"] is None:
                 inputs["decoder_input_ids"] = shift_tokens_right(
@@ -1471,16 +1477,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
         if cur_len == max_length - 1:
             logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
         return logits
-
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
-    def compute_loss(self, labels, logits):
-        """CrossEntropyLoss that ignores pad tokens"""
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.NONE,
-        )
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        return loss_fn(labels, reduced_logits)
src/transformers/models/mbart/modeling_tf_mbart.py
@@ -38,6 +38,7 @@ from ...modeling_tf_outputs import (
 # Public API
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
+    TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -1257,7 +1258,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
     "The MBART Model with a language modeling head. Can be used for summarization.",
     MBART_START_DOCSTRING,
 )
-class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
+class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageModelingLoss):
     _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
@@ -1345,6 +1346,11 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
         )
         if inputs["labels"] is not None:
+            inputs["labels"] = tf.where(
+                inputs["labels"] == self.config.pad_token_id,
+                tf.fill(shape_list(inputs["labels"]), -100),
+                inputs["labels"],
+            )
             inputs["use_cache"] = False
             if inputs["decoder_input_ids"] is None:
                 inputs["decoder_input_ids"] = shift_tokens_right(inputs["labels"], self.config.pad_token_id)
@@ -1469,16 +1475,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
             return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
         else:
             return logits
-
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
-    def compute_loss(self, labels, logits):
-        """CrossEntropyLoss that ignores pad tokens"""
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.NONE,
-        )
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        return loss_fn(labels, reduced_logits)
src/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -39,6 +39,7 @@ from ...modeling_tf_outputs import (
 # Public API
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
+    TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -1270,7 +1271,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
     "The PEGASUS Model with a language modeling head. Can be used for summarization.",
     PEGASUS_START_DOCSTRING,
 )
-class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
+class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
     _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
@@ -1358,6 +1359,11 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
         )
         if inputs["labels"] is not None:
+            inputs["labels"] = tf.where(
+                inputs["labels"] == self.config.pad_token_id,
+                tf.fill(shape_list(inputs["labels"]), -100),
+                inputs["labels"],
+            )
             inputs["use_cache"] = False
             if inputs["decoder_input_ids"] is None:
                 inputs["decoder_input_ids"] = shift_tokens_right(
@@ -1484,16 +1490,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
             return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
         else:
             return logits
-
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
-    def compute_loss(self, labels, logits):
-        """CrossEntropyLoss that ignores pad tokens"""
-        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-            from_logits=True,
-            reduction=tf.keras.losses.Reduction.NONE,
-        )
-        melted_labels = tf.reshape(labels, (-1,))
-        active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
-        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
-        labels = tf.boolean_mask(melted_labels, active_loss)
-        return loss_fn(labels, reduced_logits)
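
Taken together, the commit title can be checked end to end: labels padded with the pad token and labels already masked with -100 now yield the same loss. A toy verification using the normalize_labels and masked_lm_loss sketches above (hypothetical ids, vocab size 5):

import tensorflow as tf

logits = tf.random.normal((1, 4, 5), seed=0)  # (batch, seq_len, vocab)
padded = tf.constant([[2, 3, 1, 1]])          # pad-token convention (pad id 1)
masked = tf.constant([[2, 3, -100, -100]])    # -100 convention

loss_a = masked_lm_loss(normalize_labels(padded, pad_token_id=1), logits)
loss_b = masked_lm_loss(masked, logits)
assert bool(tf.reduce_all(loss_a == loss_b))  # identical per-token losses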