Commit 93fc3c9f authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Set `use_tf_while_loop=True` for Resnet eager benchmarks.

PiperOrigin-RevId: 300351756
parent 7708847e
...@@ -281,6 +281,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -281,6 +281,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager') FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
FLAGS.batch_size = 128 FLAGS.batch_size = 128
FLAGS.use_tf_function = False FLAGS.use_tf_function = False
FLAGS.use_tf_while_loop = False
FLAGS.single_l2_loss_op = True FLAGS.single_l2_loss_op = True
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -294,6 +295,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -294,6 +295,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.batch_size = 250 FLAGS.batch_size = 250
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.use_tf_function = False FLAGS.use_tf_function = False
FLAGS.use_tf_while_loop = False
FLAGS.single_l2_loss_op = True FLAGS.single_l2_loss_op = True
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -324,6 +326,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -324,6 +326,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.use_tf_function = False FLAGS.use_tf_function = False
FLAGS.use_tf_while_loop = False
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager')
FLAGS.batch_size = 128 FLAGS.batch_size = 128
...@@ -336,6 +339,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -336,6 +339,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.use_tf_function = False FLAGS.use_tf_function = False
FLAGS.use_tf_while_loop = False
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager_fp16') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_eager_fp16')
FLAGS.batch_size = 128 FLAGS.batch_size = 128
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment