Commit f60a4f68 authored by Brandon Jiang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 447857082
parent c0bce36e
@@ -26,7 +26,6 @@ from official.core import config_definitions
 from official.core import task_factory
 from official.core import train_utils
 from official.modeling import performance
-from official.modeling.fast_training import stage_lib
 from official.projects.token_dropping import experiment_configs  # pylint: disable=unused-import
@@ -133,54 +132,3 @@ def run_benchmark(
         startup_time=startup_time))
   return benchmark_data
-
-
-def run_fast_training_benchmark(
-    execution_mode: str,
-    params: config_definitions.ExperimentConfig,
-    model_dir: str,
-    distribution_strategy: tf.distribute.Strategy = None
-) -> Mapping[str, Any]:
-  """Runs the benchmark for a fast training experiment.
-
-  This benchmark tests, and only tests, the binary
-  tensorflow_models/official/modeling/fast_training/train.py.
-
-  Args:
-    execution_mode: A `str` specifying the mode. Can be 'accuracy',
-      'performance', or 'tflite_accuracy'.
-    params: ExperimentConfig instance.
-    model_dir: A `str`, a path to store model checkpoints and summaries.
-    distribution_strategy: A tf.distribute.Strategy to use. If specified, it
-      is used instead of inferring the strategy from params.
-
-  Returns:
-    benchmark_data: Benchmark data in dict format.
-
-  Raises:
-    NotImplementedError: If an unsupported setup is requested.
-  """
-  if execution_mode == 'performance':
-    logging.warning('Fast training benchmark does not support execution_mode '
-                    '== performance. This benchmark run will be skipped.')
-    return dict(examples_per_second=0.0,
-                wall_time=0.0,
-                startup_time=0.0)
-
-  strategy = distribution_strategy or distribute_utils.get_distribution_strategy(
-      distribution_strategy=params.runtime.distribution_strategy,
-      all_reduce_alg=params.runtime.all_reduce_alg,
-      num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu)
-
-  first_loop_start_time = time.time()
-  _, eval_logs = stage_lib.run_progressive_experiment(
-      distribution_strategy=strategy,
-      mode='train',
-      params=params,
-      model_dir=model_dir,
-      run_post_eval=True)
-  wall_time = time.time() - first_loop_start_time
-
-  return dict(metrics=eval_logs, wall_time=wall_time,
-              startup_time=0.0, examples_per_second=0.0)
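
The deleted helper inferred its tf.distribute.Strategy through the standard official/common/distribute_utils.py path, which this commit leaves untouched. A minimal runnable sketch of that resolution pattern, assuming only a default ExperimentConfig rather than a registered experiment:

import tensorflow as tf

from official.common import distribute_utils
from official.core import config_definitions

# Default runtime config: distribution_strategy='mirrored', num_gpus=0,
# all_reduce_alg=None, tpu=None.
params = config_definitions.ExperimentConfig()

# Mirrors the call inside the removed run_fast_training_benchmark above;
# with no GPUs this resolves to a CPU-backed MirroredStrategy.
strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy=params.runtime.distribution_strategy,
    all_reduce_alg=params.runtime.all_reduce_alg,
    num_gpus=params.runtime.num_gpus,
    tpu_address=params.runtime.tpu)
assert isinstance(strategy, tf.distribute.Strategy)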
@@ -16,7 +16,6 @@
 # pylint: disable=g-direct-tensorflow-import
 from absl.testing import parameterized
-import gin
 import tensorflow as tf
 from tensorflow.python.distribute import combinations
@@ -85,43 +84,6 @@ class BenchmarkLibTest(tf.test.TestCase, parameterized.TestCase):
     if execution_mode == 'accuracy':
       self.assertIn('metrics', benchmark_data)
-
-  @combinations.generate(
-      combinations.combine(
-          distribution=[
-              strategy_combinations.default_strategy,
-              strategy_combinations.cloud_tpu_strategy,
-              strategy_combinations.one_device_strategy_gpu,
-          ],
-          execution_mode=['performance', 'accuracy'],
-      ))
-  def test_fast_training_benchmark(self, distribution, execution_mode):
-    model_dir = self.get_temp_dir()
-    with gin.unlock_config():
-      gin.parse_config_files_and_bindings(
-          None,
-          "get_initialize_fn.stacking_pattern = 'dense_{:layer_id}/'\n"
-          "StageParamProgressor.stage_overrides = ("
-          "  {'trainer': {'train_steps': 1}},"
-          "  {'trainer': {'train_steps': 2}},"
-          ")")
-    params = exp_factory.get_exp_config('mock')
-    params = hyperparams.override_params_dict(
-        params, self._test_config, is_strict=True)
-    benchmark_data = benchmark_lib.run_fast_training_benchmark(
-        execution_mode, params, model_dir, distribution)
-    if execution_mode == 'performance':
-      self.assertEqual(
-          dict(examples_per_second=0.0, wall_time=0.0, startup_time=0.0),
-          benchmark_data)
-    else:
-      self.assertIn('wall_time', benchmark_data)
-      self.assertIn('metrics', benchmark_data)
 
 
 if __name__ == '__main__':
   tf.test.main()
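
The deleted test configured the progressive training stages through gin binding strings rather than flags. A self-contained sketch of that binding mechanism, using make_stage_overrides as a hypothetical stand-in configurable (StageParamProgressor itself is not importable outside the removed module):

import gin


@gin.configurable
def make_stage_overrides(stage_overrides=()):
  # Returns whatever per-stage overrides gin has bound; a stand-in for
  # StageParamProgressor.stage_overrides in the removed test.
  return stage_overrides


with gin.unlock_config():
  # Same pattern as the removed test: a single binding string parsed in place.
  gin.parse_config_files_and_bindings(
      None,
      "make_stage_overrides.stage_overrides = ("
      "  {'trainer': {'train_steps': 1}},"
      "  {'trainer': {'train_steps': 2}},"
      ")")

# Each tuple entry caps one stage, which is how the test kept the benchmark
# to 1 and 2 train steps per stage.
print(make_stage_overrides())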