"git@developer.sourcefind.cn:Fzc7075/nunchaku.git" did not exist on "6cba75248689e370b41b6765b741d6af07a619ad"
Unverified commit 3a1a56a8, authored by Matt, committed by GitHub
Browse files

Fix for sequence regression fit() in TF (#19316)


Co-authored-by: Your Name <you@example.com>
parent fe10796f
...@@ -274,6 +274,9 @@ class TFSequenceClassificationLoss: ...@@ -274,6 +274,9 @@ class TFSequenceClassificationLoss:
def hf_compute_loss(self, labels, logits): def hf_compute_loss(self, labels, logits):
if logits.shape.rank == 1 or logits.shape[1] == 1: if logits.shape.rank == 1 or logits.shape[1] == 1:
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
if labels.shape.rank == 1:
# MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that
labels = tf.expand_dims(labels, axis=-1)
else: else:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE from_logits=True, reduction=tf.keras.losses.Reduction.NONE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment