Commit bc984e50 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #7573 from houtoms:ctl_amp_double_batch_size

PiperOrigin-RevId: 269638447
parents 1862b9c3 c3db5a9f
@@ -144,7 +144,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     self._setup()
     FLAGS.num_gpus = 8
     FLAGS.data_dir = self.data_dir
-    FLAGS.batch_size = 128 * 8
+    FLAGS.batch_size = 256 * 8
     FLAGS.train_epochs = 90
     FLAGS.epochs_between_evals = 10
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
@@ -228,7 +228,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
-    FLAGS.batch_size = 128
+    FLAGS.batch_size = 256
     FLAGS.dtype = 'fp16'
     FLAGS.fp16_implementation = 'graph_rewrite'
     self._run_and_report_benchmark()
@@ -240,7 +240,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
-    FLAGS.batch_size = 128
+    FLAGS.batch_size = 256
     FLAGS.dtype = 'fp16'
     FLAGS.fp16_implementation = 'graph_rewrite'
     FLAGS.enable_xla = True
@@ -275,7 +275,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 8
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
-    FLAGS.batch_size = 128 * 8  # 8 GPUs
+    FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.dtype = 'fp16'
     FLAGS.fp16_implementation = 'graph_rewrite'
     self._run_and_report_benchmark()
@@ -287,7 +287,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.num_gpus = 8
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
-    FLAGS.batch_size = 128 * 8  # 8 GPUs
+    FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.dtype = 'fp16'
     FLAGS.fp16_implementation = 'graph_rewrite'
     FLAGS.enable_xla = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment