Commit e4748866 authored by Denali Molitor, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 457601447
parent ca6d7c57
@@ -15,7 +15,7 @@
 """TFM common benchmark training driver."""
 import os
 import time
-from typing import Any, Mapping
+from typing import Any, Mapping, Optional
 from absl import logging
 import orbit
@@ -29,6 +29,19 @@ from official.modeling import performance
 from official.projects.token_dropping import experiment_configs  # pylint: disable=unused-import


+class _OutputRecorderAction:
+  """Simple `Action` that saves the outputs passed to `__call__`."""
+
+  def __init__(self):
+    self.train_output = {}
+
+  def __call__(
+      self,
+      output: Optional[Mapping[str, tf.Tensor]] = None) -> None:
+    self.train_output = (
+        {k: v.numpy() for k, v in output.items()} if output else {})
+
+
 def run_benchmark(
     execution_mode: str,
     params: config_definitions.ExperimentConfig,
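
For context, orbit train actions are plain callables that `orbit.Controller` invokes with the output dict of each train loop. A minimal sketch of the recorder's behavior, assuming the `_OutputRecorderAction` class from the hunk above is in scope and eager execution is enabled:

```python
import tensorflow as tf

# Assumes _OutputRecorderAction from the hunk above is importable/in scope.
recorder = _OutputRecorderAction()

# The controller passes the train loop's output dict to each action;
# the recorder snapshots it as plain numpy values.
recorder({'total_loss': tf.constant(0.25), 'learning_rate': tf.constant(1e-3)})
print(recorder.train_output)  # {'total_loss': 0.25, 'learning_rate': 0.001}

# Calling with no output resets the snapshot to an empty dict.
recorder(None)
print(recorder.train_output)  # {}
```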
@@ -82,10 +95,13 @@ def run_benchmark(
   steps_per_loop = params.trainer.steps_per_loop if (
       execution_mode in ['accuracy', 'tflite_accuracy']) else 100
+  train_output_recorder = _OutputRecorderAction()
+
   controller = orbit.Controller(
       strategy=strategy,
       trainer=trainer,
       evaluator=trainer if (execution_mode == 'accuracy') else None,
+      train_actions=[train_output_recorder],
       global_step=trainer.global_step,
       steps_per_loop=steps_per_loop)
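
The wiring above follows orbit's standard action mechanism: the `Controller` runs the trainer for `steps_per_loop` steps, then calls each entry in `train_actions` with that loop's output. A runnable sketch of the same pattern, with a hypothetical toy trainer standing in for the real TFM trainer (names here are illustrative, not from the commit):

```python
import orbit
import tensorflow as tf


class _ToyTrainer(orbit.AbstractTrainer):
  """Hypothetical stand-in for the TFM trainer used by this driver."""

  def __init__(self):
    super().__init__()
    self.global_step = tf.Variable(0, dtype=tf.int64, trainable=False)

  def train(self, num_steps):
    # A real trainer would run its train loop here; this stub only
    # advances the step counter and reports a fake loss.
    self.global_step.assign_add(tf.cast(num_steps, self.global_step.dtype))
    return {'total_loss': tf.constant(0.123)}


trainer = _ToyTrainer()
recorder = _OutputRecorderAction()  # class from the hunk above
controller = orbit.Controller(
    trainer=trainer,
    train_actions=[recorder],
    global_step=trainer.global_step,
    steps_per_loop=10)
controller.train(steps=20)  # recorder fires after each 10-step loop
print(recorder.train_output)  # {'total_loss': 0.123}
```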
@@ -108,7 +124,10 @@ def run_benchmark(
         tf.convert_to_tensor(params.trainer.validation_steps))
     benchmark_data = {'metrics': eval_logs}
   elif execution_mode == 'performance':
-    benchmark_data = {}
+    if train_output_recorder.train_output:
+      benchmark_data = {'metrics': train_output_recorder.train_output}
+    else:
+      benchmark_data = {}
   elif execution_mode == 'tflite_accuracy':
     eval_logs = tflite_utils.train_and_evaluate(
         params, task, trainer, controller)
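
The effect of this hunk: in 'performance' mode the driver now attaches the last recorded train loop output as metrics when one was captured, rather than always returning an empty dict. A tiny sketch of the same fallback, with a hypothetical recorder snapshot:

```python
# Hypothetical snapshot from the recorder after training; empty if the
# train actions never fired (e.g. zero train steps).
train_output = {'total_loss': 0.37}

# Equivalent to the if/else added above, as a conditional expression.
benchmark_data = {'metrics': train_output} if train_output else {}
assert benchmark_data == {'metrics': {'total_loss': 0.37}}
```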
@@ -80,9 +80,7 @@ class BenchmarkLibTest(tf.test.TestCase, parameterized.TestCase):
     self.assertIn('examples_per_second', benchmark_data)
     self.assertIn('wall_time', benchmark_data)
     self.assertIn('startup_time', benchmark_data)
-    if execution_mode == 'accuracy':
-      self.assertIn('metrics', benchmark_data)
+    self.assertIn('metrics', benchmark_data)


 if __name__ == '__main__':