Commit 1a8f129b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

remove calls to metrics.result() inside the train_step

PiperOrigin-RevId: 364096705
parent 2b91303a
...@@ -302,7 +302,6 @@ class TranslationTask(base_task.Task): ...@@ -302,7 +302,6 @@ class TranslationTask(base_task.Task):
logs = {self.loss: loss} logs = {self.loss: loss}
if metrics: if metrics:
self.process_metrics(metrics, inputs["targets"], outputs) self.process_metrics(metrics, inputs["targets"], outputs)
logs.update({m.name: m.result() for m in metrics})
return logs return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
......
...@@ -199,7 +199,6 @@ class ImageClassificationTask(base_task.Task): ...@@ -199,7 +199,6 @@ class ImageClassificationTask(base_task.Task):
logs = {self.loss: loss} logs = {self.loss: loss}
if metrics: if metrics:
self.process_metrics(metrics, labels, outputs) self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics: elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs) self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics}) logs.update({m.name: m.result() for m in model.metrics})
...@@ -228,7 +227,6 @@ class ImageClassificationTask(base_task.Task): ...@@ -228,7 +227,6 @@ class ImageClassificationTask(base_task.Task):
logs = {self.loss: loss} logs = {self.loss: loss}
if metrics: if metrics:
self.process_metrics(metrics, labels, outputs) self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics: elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs) self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics}) logs.update({m.name: m.result() for m in model.metrics})
......
...@@ -291,7 +291,6 @@ class MaskRCNNTask(base_task.Task): ...@@ -291,7 +291,6 @@ class MaskRCNNTask(base_task.Task):
if metrics: if metrics:
for m in metrics: for m in metrics:
m.update_state(losses[m.name]) m.update_state(losses[m.name])
logs.update({m.name: m.result()})
return logs return logs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment