Unverified Commit 26d9212e authored by Matt's avatar Matt Committed by GitHub
Browse files

TF multiple choice loss fix (#13513)

Fix issues with `TFMultipleChoiceLoss` if the choices dimension is None when `build()` is called.
parent d7b3b709
......@@ -220,9 +220,15 @@ class TFSequenceClassificationLoss:
return loss_fn(labels, logits)
class TFMultipleChoiceLoss(TFSequenceClassificationLoss):
class TFMultipleChoiceLoss:
"""Loss function suitable for multiple choice tasks."""
def compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
return loss_fn(labels, logits)
class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment