Commit ffb6dbf3 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower

Use sparse_categorical_crossentropy for the test, as the loss object's default reduction does not work with TPUStrategy, and the single task trainer already handles the reduction.

PiperOrigin-RevId: 367757677
parent e353e4e5
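
For context, a minimal sketch (not part of the commit) of the two ways of supplying the loss that the message contrasts: a `tf.keras.losses.Loss` object defaults to `Reduction.AUTO`, which raises an error when invoked under a `tf.distribute` strategy outside of `Model.fit`, whereas the plain function returns unreduced per-example losses that the trainer can reduce itself.

```python
import tensorflow as tf

# Illustrative sketch only; not part of the commit.
# A Loss object's default reduction (AUTO) raises an error when called
# under a tf.distribute strategy outside of the built-in training loops,
# so the reduction would have to be set explicitly:
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

# The plain function returns unreduced per-example losses, which the
# single task trainer reduces and scales itself:
loss_fn = tf.keras.losses.sparse_categorical_crossentropy
```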
@@ -107,6 +107,10 @@ class SingleTaskTrainer(orbit.StandardTrainer):
       # replicas. This ensures that we don't end up multiplying our loss by
       # the number of workers - gradients are summed, not averaged, across
       # replicas during the apply_gradients call.
+      # Note: the loss reduction is handled explicitly here and scaled by
+      # num_replicas_in_sync, so a plain loss function is recommended.
+      # If you use a tf.keras.losses.Loss object, you may need to set its
+      # reduction argument explicitly.
       loss = tf.reduce_mean(self.loss_fn(target, output))
       scaled_loss = loss / self.strategy.num_replicas_in_sync
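
A minimal, self-contained sketch of the reduction described in the comment above (names such as `loss_fn`, `target`, and `output` are assumed from the surrounding trainer code):

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # any tf.distribute strategy

def compute_scaled_loss(loss_fn, target, output):
  # loss_fn returns one loss value per example, e.g.
  # tf.keras.losses.sparse_categorical_crossentropy.
  per_example_loss = loss_fn(target, output)
  # Mean over the per-replica batch.
  loss = tf.reduce_mean(per_example_loss)
  # Gradients are summed, not averaged, across replicas during
  # apply_gradients, so dividing by the replica count keeps the
  # effective loss a global batch average.
  return loss / strategy.num_replicas_in_sync
```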
@@ -30,14 +30,15 @@ class SingleTaskTrainerTest(tf.test.TestCase):
         tf.keras.Input(shape=(4,), name='features'),
         tf.keras.layers.Dense(10, activation=tf.nn.relu),
         tf.keras.layers.Dense(10, activation=tf.nn.relu),
-        tf.keras.layers.Dense(3)
+        tf.keras.layers.Dense(3),
+        tf.keras.layers.Softmax(),
     ])
     trainer = single_task_trainer.SingleTaskTrainer(
         train_ds,
         label_key='label',
         model=model,
-        loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        loss_fn=tf.keras.losses.sparse_categorical_crossentropy,
         optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))
     controller = orbit.Controller(
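
The plain function defaults to `from_logits=False`, which is why the test model gains a final `Softmax` layer in this hunk. A hedged, self-contained sketch of that pairing (random data, illustrative only):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,), name='features'),
    tf.keras.layers.Dense(10, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.relu),
    tf.keras.layers.Dense(3),
    tf.keras.layers.Softmax(),  # emits probabilities, matching from_logits=False
])

features = tf.random.uniform([8, 4])
labels = tf.random.uniform([8], maxval=3, dtype=tf.int32)

# Returns an unreduced [8] vector of per-example losses; the trainer
# applies tf.reduce_mean and the replica scaling shown in the first hunk.
per_example = tf.keras.losses.sparse_categorical_crossentropy(
    labels, model(features))
assert per_example.shape == (8,)
```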