Commit fa9ed456 authored by Haoyu Zhang, committed by Toby Boyd

Add Keras XLA Tests (#6286)

* Added XLA test with a monkey-patched op to avoid OOM

* Added doc strings in Keras benchmarks to avoid Lint error
parent a76cd3ac
@@ -133,6 +133,9 @@ def get_config_proto():
  """Return config proto according to flag settings, or None to use default."""
  config = None
  if FLAGS.enable_xla:
    # TODO(haoyuzhang): Remove this monkey patch when XLA OOM issue is fixed.
    _monkey_patch_org_assert_broadcastable()
    config = tf.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_2)
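
For context, a minimal sketch of how a ConfigProto built this way might be consumed; the explicit tf.Session wiring below is only an illustration and is not part of this diff:

import tensorflow as tf

config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
    tf.OptimizerOptions.ON_2)  # turn on XLA auto-clustering for the graph
# In legacy graph mode, the proto is passed to the session that runs the model.
with tf.Session(config=config) as sess:
  pass  # graphs executed via `sess` are now eligible for XLA JIT compilation
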
@@ -295,3 +298,26 @@ class DummyContextManager(object):

  def __exit__(self, *args):
    pass


def _monkey_patch_org_assert_broadcastable():
  """Monkey-patch `assert_broadcastable` op to avoid OOM when enabling XLA."""

  def no_op_assert_broadcastable(weights, values):
    del weights, values
    tf.compat.v1.logging.info(
        'Using monkey-patched version of assert_broadcastable op, which always '
        'returns a no_op. It should be removed after XLA OOM issue is fixed.')
    return tf.constant([], dtype=tf.float32)

  from tensorflow.python.ops import weights_broadcast_ops  # pylint: disable=g-import-not-at-top
  if not hasattr(weights_broadcast_ops, 'org_assert_broadcastable'):
    weights_broadcast_ops.org_assert_broadcastable = (
        weights_broadcast_ops.assert_broadcastable)
  weights_broadcast_ops.assert_broadcastable = no_op_assert_broadcastable


def _undo_monkey_patch_org_assert_broadcastable():
  """Restore the original `assert_broadcastable` op after monkey-patching."""
  from tensorflow.python.ops import weights_broadcast_ops  # pylint: disable=g-import-not-at-top
  if hasattr(weights_broadcast_ops, 'org_assert_broadcastable'):
    weights_broadcast_ops.assert_broadcastable = (
        weights_broadcast_ops.org_assert_broadcastable)
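
As a usage sketch (an assumption, not part of this change), the patch and its undo can bracket a single XLA-enabled run so module-level state is restored afterwards; run_xla_benchmark is a hypothetical caller:

_monkey_patch_org_assert_broadcastable()
try:
  # While patched, callers of weights_broadcast_ops.assert_broadcastable get a
  # cheap constant back instead of the real check, sidestepping the XLA OOM
  # noted in the TODO above.
  run_xla_benchmark()  # hypothetical caller that builds and runs the model
finally:
  _undo_monkey_patch_org_assert_broadcastable()
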
@@ -120,6 +120,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
                           log_steps=FLAGS.log_steps)

  def benchmark_1_gpu_no_dist_strat(self):
    """Test Keras model with 1 GPU, no distribution strategy."""
    self._setup()

    FLAGS.num_gpus = 1
@@ -130,6 +131,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    self._run_and_report_benchmark()

  def benchmark_graph_1_gpu_no_dist_strat(self):
    """Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
    self._setup()

    FLAGS.num_gpus = 1
@@ -140,6 +142,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    self._run_and_report_benchmark()

  def benchmark_1_gpu(self):
    """Test Keras model with 1 GPU."""
    self._setup()

    FLAGS.num_gpus = 1
@@ -149,7 +152,20 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.batch_size = 128
    self._run_and_report_benchmark()
  def benchmark_xla_1_gpu(self):
    """Test Keras model with XLA and 1 GPU."""
    self._setup()

    FLAGS.num_gpus = 1
    FLAGS.enable_eager = True
    FLAGS.enable_xla = True
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
    FLAGS.batch_size = 128
    self._run_and_report_benchmark()

  def benchmark_graph_1_gpu(self):
    """Test Keras model in legacy graph mode with 1 GPU."""
    self._setup()

    FLAGS.num_gpus = 1
@@ -159,7 +175,20 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.batch_size = 128
    self._run_and_report_benchmark()
  def benchmark_graph_xla_1_gpu(self):
    """Test Keras model in legacy graph mode with XLA and 1 GPU."""
    self._setup()

    FLAGS.num_gpus = 1
    FLAGS.enable_eager = False
    FLAGS.enable_xla = True
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_xla_1_gpu')
    FLAGS.batch_size = 128
    self._run_and_report_benchmark()

  def benchmark_8_gpu(self):
    """Test Keras model with 8 GPUs."""
    self._setup()

    FLAGS.num_gpus = 8
@@ -170,6 +199,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    self._run_and_report_benchmark()

  def benchmark_8_gpu_tweaked(self):
    """Test Keras model with manual config tuning and 8 GPUs."""
    self._setup()

    FLAGS.num_gpus = 8
@@ -180,7 +210,21 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()
  def benchmark_xla_8_gpu(self):
    """Test Keras model with XLA and 8 GPUs."""
    self._setup()

    FLAGS.num_gpus = 8
    FLAGS.enable_eager = True
    FLAGS.enable_xla = True
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
    # TODO(haoyuzhang): Set size to 128 per GPU when multi-GPU XLA OOM is fixed
    FLAGS.batch_size = 64 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def benchmark_graph_8_gpu(self):
    """Test Keras model in legacy graph mode with 8 GPUs."""
    self._setup()

    FLAGS.num_gpus = 8
@@ -190,6 +234,19 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.batch_size = 128 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def benchmark_graph_xla_8_gpu(self):
    """Test Keras model in legacy graph mode with XLA and 8 GPUs."""
    self._setup()

    FLAGS.num_gpus = 8
    FLAGS.enable_eager = False
    FLAGS.enable_xla = True
    FLAGS.distribution_strategy = 'default'
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_xla_8_gpu')
    # TODO(haoyuzhang): Set size to 128 per GPU when multi-GPU XLA OOM is fixed
    FLAGS.batch_size = 64 * 8  # 8 GPUs
    self._run_and_report_benchmark()

  def fill_report_object(self, stats):
    super(Resnet50KerasBenchmarkBase, self).fill_report_object(
        stats,