Commit 5406d9c7 authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 391656704
parent 5d29a2e2
@@ -21,6 +21,7 @@ import pprint
 # Import libraries
 from absl import logging
+import gin
 import tensorflow as tf
 from tensorflow.python.platform import benchmark  # pylint: disable=unused-import
@@ -46,13 +47,15 @@ def _get_benchmark_params(benchmark_models, eval_tflite=False):
       benchmark_params = (
          benchmark_name,  # First arg is used by ParameterizedBenchmark.
          benchmark_name,
+         params.get('benchmark_function') or benchmark_lib.run_benchmark,
          params['experiment_type'],
          execution_mode,
          params['platform'],
          params['precision'],
          params['metric_bounds'],
          params.get('config_files') or [],
-         params.get('params_override') or None)
+         params.get('params_override') or None,
+         params.get('gin_file') or [])
       parameterized_benchmark_params.append(benchmark_params)
   return parameterized_benchmark_params
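
For illustration only (not part of this commit): a benchmark registry entry could populate the two new fields roughly as below. The dict-based entry and its keys are assumptions modeled on the tuple built above, with benchmark_lib assumed to be imported in the defining module.

    # Hypothetical registry entry; all names here are illustrative.
    fast_training_mock = {
        'experiment_type': 'mock',
        'platform': 'gpu',
        'precision': 'float32',
        'metric_bounds': {},
        'config_files': [],
        'params_override': None,
        # New in this commit: a pluggable runner and gin config files.
        'benchmark_function': benchmark_lib.run_fast_training_benchmark,
        'gin_file': ['fast_training.gin'],
    }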
@@ -103,13 +106,19 @@ class BaseBenchmark(  # pylint: disable=undefined-variable
   def benchmark(self,
                 benchmark_name,
+                benchmark_function,
                 experiment_type,
                 execution_mode,
                 platform,
                 precision,
                 metric_bounds,
                 config_files,
-                params_override):
+                params_override,
+                gin_file):
+    with gin.unlock_config():
+      gin.parse_config_files_and_bindings(
+          [config_utils.get_config_path(g) for g in gin_file], None)
+
     params = exp_factory.get_exp_config(experiment_type)
@@ -145,7 +154,7 @@ class BaseBenchmark(  # pylint: disable=undefined-variable
     logging.info('Final experiment parameters: %s',
                  pp.pformat(params.as_dict()))
-    benchmark_data = benchmark_lib.run_benchmark(
+    benchmark_data = benchmark_function(
         execution_mode, params, self._get_model_dir(benchmark_name))
     metrics = []
...
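
As a side note on the gin call used above and again in the test further down: gin.parse_config_files_and_bindings accepts a list of .gin config files and a list (or single string) of raw bindings, and either argument may be None. A minimal, self-contained sketch with an illustrative configurable and binding (not from this commit):

    import gin

    @gin.configurable
    def make_learning_rate(value=0.1):
      return value

    with gin.unlock_config():
      # No config files, one raw binding string.
      gin.parse_config_files_and_bindings(None, 'make_learning_rate.value = 0.01')

    assert make_learning_rate() == 0.01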
@@ -27,6 +27,7 @@ from official.core import config_definitions
 from official.core import task_factory
 from official.core import train_utils
 from official.modeling import performance
+from official.modeling.fast_training import stage_lib


 def run_benchmark(
@@ -132,3 +133,54 @@ def run_benchmark(
         startup_time=startup_time))
   return benchmark_data
+
+
+def run_fast_training_benchmark(
+    execution_mode: str,
+    params: config_definitions.ExperimentConfig,
+    model_dir: str,
+    distribution_strategy: tf.distribute.Strategy = None
+) -> Mapping[str, Any]:
+  """Runs the benchmark for a fast-training experiment.
+
+  This benchmark exercises only the binary
+  tensorflow_models/official/modeling/fast_training/train.py.
+
+  Args:
+    execution_mode: A 'str' specifying the mode. Can be 'accuracy',
+      'performance', or 'tflite_accuracy'.
+    params: ExperimentConfig instance.
+    model_dir: A 'str', a path to store model checkpoints and summaries.
+    distribution_strategy: A tf.distribute.Strategy to use. If specified,
+      it is used instead of the strategy inferred from params.
+
+  Returns:
+    benchmark_data: Benchmark data in dict format.
+
+  Raises:
+    NotImplementedError: If an unsupported setup is requested.
+  """
+  if execution_mode == 'performance':
+    logging.warning('Fast training benchmark does not support execution_mode '
+                    '== performance. This benchmark run will be skipped.')
+    return dict(examples_per_second=0.0,
+                wall_time=0.0,
+                startup_time=0.0)
+
+  strategy = distribution_strategy or distribute_utils.get_distribution_strategy(
+      distribution_strategy=params.runtime.distribution_strategy,
+      all_reduce_alg=params.runtime.all_reduce_alg,
+      num_gpus=params.runtime.num_gpus,
+      tpu_address=params.runtime.tpu)
+
+  first_loop_start_time = time.time()
+  _, eval_logs = stage_lib.run_progressive_experiment(
+      distribution_strategy=strategy,
+      mode='train',
+      params=params,
+      model_dir=model_dir,
+      run_post_eval=True)
+  wall_time = time.time() - first_loop_start_time
+
+  return dict(metrics=eval_logs, wall_time=wall_time,
+              startup_time=0.0, examples_per_second=0.0)
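
A hedged usage sketch of the new entry point (not part of the commit), assuming an experiment named 'mock' is registered with exp_factory as in the test below, and that the surrounding module's imports (exp_factory, tf) are available:

    params = exp_factory.get_exp_config('mock')  # 'mock' is an assumption
    benchmark_data = run_fast_training_benchmark(
        execution_mode='accuracy',
        params=params,
        model_dir='/tmp/fast_training_benchmark',  # illustrative path
        distribution_strategy=tf.distribute.get_strategy())
    print(benchmark_data['wall_time'], benchmark_data['metrics'])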
@@ -17,6 +17,7 @@
 # pylint: disable=g-direct-tensorflow-import
 from absl.testing import parameterized
+import gin
 import tensorflow as tf
 from tensorflow.python.distribute import combinations
@@ -85,5 +86,43 @@ class BenchmarkLibTest(tf.test.TestCase, parameterized.TestCase):
     if execution_mode == 'accuracy':
       self.assertIn('metrics', benchmark_data)

+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          execution_mode=['performance', 'accuracy'],
+      ))
+  def test_fast_training_benchmark(self, distribution, execution_mode):
+    model_dir = self.get_temp_dir()
+    with gin.unlock_config():
+      gin.parse_config_files_and_bindings(
+          None,
+          "get_initialize_fn.stacking_pattern = 'dense_{:layer_id}/'\n"
+          "StageParamProgressor.stage_overrides = ("
+          "  {'trainer': {'train_steps': 1}},"
+          "  {'trainer': {'train_steps': 2}},"
+          ")")
+    params = exp_factory.get_exp_config('mock')
+    params = hyperparams.override_params_dict(
+        params, self._test_config, is_strict=True)
+    benchmark_data = benchmark_lib.run_fast_training_benchmark(execution_mode,
+                                                               params,
+                                                               model_dir,
+                                                               distribution)
+    if execution_mode == 'performance':
+      self.assertEqual(dict(examples_per_second=0.0,
+                            wall_time=0.0,
+                            startup_time=0.0),
+                       benchmark_data)
+    else:
+      self.assertIn('wall_time', benchmark_data)
+      self.assertIn('metrics', benchmark_data)
+

 if __name__ == '__main__':
   tf.test.main()