Commit bab477df authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 378899878
parent 93ae9c4d
...@@ -246,10 +246,11 @@ class Trainer(_AsyncTrainer): ...@@ -246,10 +246,11 @@ class Trainer(_AsyncTrainer):
self._train_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32) self._train_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
self._validation_loss = tf.keras.metrics.Mean( self._validation_loss = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32) "validation_loss", dtype=tf.float32)
model_metrics = model.metrics if hasattr(model, "metrics") else []
self._train_metrics = self.task.build_metrics( self._train_metrics = self.task.build_metrics(
training=True) + self.model.metrics training=True) + model_metrics
self._validation_metrics = self.task.build_metrics( self._validation_metrics = self.task.build_metrics(
training=False) + self.model.metrics training=False) + model_metrics
self.init_async() self.init_async()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment