Commit 66404fbb authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 366819679
parent 90e17531
...@@ -217,6 +217,12 @@ class Task(tf.Module, metaclass=abc.ABCMeta): ...@@ -217,6 +217,12 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = model(features, training=True) outputs = model(features, training=True)
# Computes per-replica loss. # Computes per-replica loss.
if model.compiled_loss:
loss = model.compiled_loss(
labels, outputs, regularization_losses=model.losses)
loss += self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=None)
else:
loss = self.build_losses( loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses) labels=labels, model_outputs=outputs, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the # Scales loss as the default gradients allreduce performs sum inside the
......
...@@ -368,6 +368,18 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -368,6 +368,18 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
_ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32)) _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
def test_model_with_compiled_loss(self):
task = mock_task.MockTask()
model = task.build_model()
model.compile(loss=tf.keras.losses.CategoricalCrossentropy())
trainer = trainer_lib.Trainer(
self._config,
task,
model=model,
optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment