Commit e4748866 authored by Denali Molitor's avatar Denali Molitor Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 457601447
parent ca6d7c57
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""TFM common benchmark training driver.""" """TFM common benchmark training driver."""
import os import os
import time import time
from typing import Any, Mapping from typing import Any, Mapping, Optional
from absl import logging from absl import logging
import orbit import orbit
...@@ -29,6 +29,19 @@ from official.modeling import performance ...@@ -29,6 +29,19 @@ from official.modeling import performance
from official.projects.token_dropping import experiment_configs # pylint: disable=unused-import from official.projects.token_dropping import experiment_configs # pylint: disable=unused-import
class _OutputRecorderAction:
"""Simple `Action` that saves the outputs passed to `__call__`."""
def __init__(self):
self.train_output = {}
def __call__(
self,
output: Optional[Mapping[str, tf.Tensor]] = None) -> Mapping[str, Any]:
self.train_output = {k: v.numpy() for k, v in output.items()
} if output else {}
def run_benchmark( def run_benchmark(
execution_mode: str, execution_mode: str,
params: config_definitions.ExperimentConfig, params: config_definitions.ExperimentConfig,
...@@ -82,10 +95,13 @@ def run_benchmark( ...@@ -82,10 +95,13 @@ def run_benchmark(
steps_per_loop = params.trainer.steps_per_loop if ( steps_per_loop = params.trainer.steps_per_loop if (
execution_mode in ['accuracy', 'tflite_accuracy']) else 100 execution_mode in ['accuracy', 'tflite_accuracy']) else 100
train_output_recorder = _OutputRecorderAction()
controller = orbit.Controller( controller = orbit.Controller(
strategy=strategy, strategy=strategy,
trainer=trainer, trainer=trainer,
evaluator=trainer if (execution_mode == 'accuracy') else None, evaluator=trainer if (execution_mode == 'accuracy') else None,
train_actions=[train_output_recorder],
global_step=trainer.global_step, global_step=trainer.global_step,
steps_per_loop=steps_per_loop) steps_per_loop=steps_per_loop)
...@@ -108,7 +124,10 @@ def run_benchmark( ...@@ -108,7 +124,10 @@ def run_benchmark(
tf.convert_to_tensor(params.trainer.validation_steps)) tf.convert_to_tensor(params.trainer.validation_steps))
benchmark_data = {'metrics': eval_logs} benchmark_data = {'metrics': eval_logs}
elif execution_mode == 'performance': elif execution_mode == 'performance':
benchmark_data = {} if train_output_recorder.train_output:
benchmark_data = {'metrics': train_output_recorder.train_output}
else:
benchmark_data = {}
elif execution_mode == 'tflite_accuracy': elif execution_mode == 'tflite_accuracy':
eval_logs = tflite_utils.train_and_evaluate( eval_logs = tflite_utils.train_and_evaluate(
params, task, trainer, controller) params, task, trainer, controller)
......
...@@ -80,9 +80,7 @@ class BenchmarkLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -80,9 +80,7 @@ class BenchmarkLibTest(tf.test.TestCase, parameterized.TestCase):
self.assertIn('examples_per_second', benchmark_data) self.assertIn('examples_per_second', benchmark_data)
self.assertIn('wall_time', benchmark_data) self.assertIn('wall_time', benchmark_data)
self.assertIn('startup_time', benchmark_data) self.assertIn('startup_time', benchmark_data)
self.assertIn('metrics', benchmark_data)
if execution_mode == 'accuracy':
self.assertIn('metrics', benchmark_data)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment