Commit df0179e0 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 355065007
parent 59434199
...@@ -320,9 +320,12 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator): ...@@ -320,9 +320,12 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
# `self.validation_loss` metric was not updated, because the validation # `self.validation_loss` metric was not updated, because the validation
# loss was not returned from the task's `validation_step` method. # loss was not returned from the task's `validation_step` method.
logging.info("The task did not report validation loss.") logging.info("The task did not report validation loss.")
if aggregated_logs:
metrics = self.task.reduce_aggregated_logs(aggregated_logs) # Merges additional metrics from `reduce_aggregated_logs` method.
logs.update(metrics) # By default, the method in `base_task.Task` returns an empty dict, while
# the subclass may override it to return metrics computed on host.
metrics = self.task.reduce_aggregated_logs(aggregated_logs)
logs.update(metrics)
if self._checkpoint_exporter: if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint( self._checkpoint_exporter.maybe_export_checkpoint(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment