Unverified Commit 92bad0d2 authored by rxsang's avatar rxsang Committed by GitHub
Browse files

Add a test enabling get_next_as_optional behavior. (#6862)

* Add a test enabling get_next_as_optional behavior.

* Remove repeated flag.

* Remove trailing space.

* Make the name shorter.

* Fix lint error.

* Refine the benchmark name.
parent 68650c42
...@@ -568,6 +568,26 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -568,6 +568,26 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.data_delay_prefetch = True FLAGS.data_delay_prefetch = True
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_xla_8_gpu_fp16_tweaked_optional_next(self):
"""Test Keras model with manual config tuning, XLA, 8 GPUs, fp16 and
enabling get_next_as_optional.
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.enable_eager = True
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir(
'benchmark_xla_8_gpu_fp16_tweaked_optional_next')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.use_tensor_lr = True
# FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.data_delay_prefetch = True
FLAGS.enable_get_next_as_optional = True
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_fp16_slack(self): def benchmark_xla_8_gpu_fp16_slack(self):
"""Test Keras model with tf.data's experimental_slack functionality, XLA, """Test Keras model with tf.data's experimental_slack functionality, XLA,
8 GPUs and fp16. 8 GPUs and fp16.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment