Unverified Commit 0a96c7b4 authored by Reed's avatar Reed Committed by GitHub
Browse files

Add graph rewrite convergence benchmark (#6712)

parent 17ecf9db
...@@ -160,6 +160,21 @@ class Resnet50EstimatorAccuracy(EstimatorBenchmark): ...@@ -160,6 +160,21 @@ class Resnet50EstimatorAccuracy(EstimatorBenchmark):
FLAGS.hooks = ['ExamplesPerSecondHook'] FLAGS.hooks = ['ExamplesPerSecondHook']
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_graph_fp16_graph_rewrite_8_gpu(self):
"""Test FP16 graph rewrite 8 GPUs graph mode."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 256 * 8
FLAGS.train_epochs = 90
FLAGS.epochs_between_evals = 10
FLAGS.model_dir = self._get_model_dir(
'benchmark_graph_fp16_graph_rewrite_8_gpu')
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.hooks = ['ExamplesPerSecondHook']
self._run_and_report_benchmark()
def _run_and_report_benchmark(self): def _run_and_report_benchmark(self):
start_time_sec = time.time() start_time_sec = time.time()
stats = imagenet_main.run_imagenet(flags.FLAGS) stats = imagenet_main.run_imagenet(flags.FLAGS)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment