Commit 48192d54 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add global_step to reduce_aggregated_logs().

PiperOrigin-RevId: 360256877
parent c2793c1b
......@@ -291,6 +291,8 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
"""Optional aggregation over logs returned from a validation step."""
pass
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self,
aggregated_logs,
global_step: Optional[tf.Tensor] = None):
"""Optional reduce of aggregated logs over validation steps."""
return {}
......@@ -334,7 +334,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
# loss was not returned from the task's `validation_step` method.
logging.info("The task did not report validation loss.")
if aggregated_logs:
metrics = self.task.reduce_aggregated_logs(aggregated_logs)
metrics = self.task.reduce_aggregated_logs(
aggregated_logs, global_step=self.global_step)
logs.update(metrics)
if self._checkpoint_exporter:
......
......@@ -161,7 +161,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
for metric in task_metrics + [task_loss]:
logs[metric.name] = metric.result()
if outputs:
metrics = task.reduce_aggregated_logs(outputs)
metrics = task.reduce_aggregated_logs(
outputs, global_step=self.global_step)
logs.update(metrics)
results[name] = logs
......
......@@ -89,7 +89,9 @@ class MockTask(base_task.Task):
np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self,
aggregated_logs,
global_step=None):
for k, v in aggregated_logs.items():
aggregated_logs[k] = np.sum(np.stack(v, axis=0))
return aggregated_logs
......
......@@ -277,7 +277,7 @@ class QuestionAnsweringTask(base_task.Task):
end_logits=values[2]))
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
all_predictions, _, scores_diff = (
self.squad_lib.postprocess_output(
self._eval_examples,
......
......@@ -183,7 +183,7 @@ class SentencePredictionTask(base_task.Task):
np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0))
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
if self.metric_type == 'accuracy':
return None
elif self.metric_type == 'matthews_corrcoef':
......
......@@ -189,7 +189,7 @@ class TaggingTask(base_task.Task):
state['label_class'].extend(id_to_class_name(step_outputs['label_ids']))
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
"""Reduces aggregated logs over validation steps."""
label_class = aggregated_logs['label_class']
predict_class = aggregated_logs['predict_class']
......
......@@ -338,7 +338,7 @@ class TranslationTask(base_task.Task):
state[u_id] = (in_ids, out_ids)
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
def _decode(ids):
return self._sp_tokenizer.detokenize(ids).numpy().decode()
......
......@@ -88,7 +88,7 @@ class MockTask(base_task.Task):
np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
for k, v in aggregated_logs.items():
aggregated_logs[k] = np.sum(np.stack(v, axis=0))
return aggregated_logs
......
......@@ -341,5 +341,5 @@ class MaskRCNNTask(base_task.Task):
step_outputs[self.coco_metric.name][1])
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
return self.coco_metric.result()
......@@ -292,5 +292,5 @@ class RetinaNetTask(base_task.Task):
step_outputs[self.coco_metric.name][1])
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
return self.coco_metric.result()
......@@ -263,7 +263,7 @@ class SemanticSegmentationTask(base_task.Task):
step_outputs[self.iou_metric.name][1])
return state
def reduce_aggregated_logs(self, aggregated_logs):
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
result = {}
ious = self.iou_metric.result()
# TODO(arashwan): support loading class name from a label map file.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment