Commit 4571d3fa authored by Haoyu Zhang, committed by Toby Boyd

Add flag to enable XLA in Keras models (#6240)

* Add flag to enable XLA in Keras models

* Fix lint errors (some of them are old errors)
parent 733535dd
@@ -53,6 +53,7 @@ def learning_rate_schedule(current_epoch,
  Returns:
    Adjusted learning rate.
  """
+  del current_batch, batches_per_epoch  # not used
  initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
  learning_rate = initial_learning_rate
  for mult, start_epoch in LR_SCHEDULE:
@@ -97,10 +98,16 @@ def run(flags_obj):
  Returns:
    Dictionary of training and eval stats.
  """
+  config = keras_common.get_config_proto()
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
-  if flags_obj.enable_eager and not keras_common.is_v2_0():
-    tf.compat.v1.enable_eager_execution()
+  if not keras_common.is_v2_0():
+    if flags_obj.enable_eager:
+      tf.compat.v1.enable_eager_execution(config=config)
+    else:
+      sess = tf.Session(config=config)
+      tf.keras.backend.set_session(sess)
+  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'fp16':
......
@@ -25,6 +25,8 @@ import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
                                                   gradient_descent_v2)
@@ -49,6 +51,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
    Args:
      batch_size: Total batch size.
+      log_steps: Interval of time history logs.
    """
    self.batch_size = batch_size
@@ -126,6 +129,20 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
                 'change learning rate to %s.', self.epochs, batch, lr)

+def get_config_proto():
+  """Return config proto according to flag settings, or None to use default."""
+  config = None
+  if FLAGS.enable_xla:
+    config = tf.ConfigProto()
+    config.graph_options.optimizer_options.global_jit_level = (
+        tf.OptimizerOptions.ON_2)
+    # Disable PinToHostOptimizer in grappler when enabling XLA because it causes
+    # OOM and performance regression.
+    config.graph_options.rewrite_options.pin_to_host_optimization = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+  return config
+

def get_optimizer():
  """Returns optimizer to use."""
  # The learning_rate is overwritten at the beginning of each step by callback.
@@ -189,8 +206,13 @@ def build_stats(history, eval_output, time_callback):
def define_keras_flags():
+  """Define flags for Keras models."""
  flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
  flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
+  flags.DEFINE_boolean(
+      name='enable_xla', default=False,
+      help='Whether to enable XLA auto jit compilation. This is still an '
+           'experimental feature, and is not yet effective with TF 2.0.')
  flags.DEFINE_integer(
      name='train_steps', default=None,
      help='The number of steps to run for training. If it is larger than '
......
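For reference, here is a minimal, self-contained sketch (not part of the commit) of how the new --enable_xla flag and get_config_proto() fit together under TF 1.x. The main() wrapper and the print statement are illustrative only; the flag name, ConfigProto fields, and RewriterConfig usage follow the diff above.

from absl import app
from absl import flags
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

FLAGS = flags.FLAGS
flags.DEFINE_boolean(name='enable_xla', default=False,
                     help='Whether to enable XLA auto jit compilation.')


def get_config_proto():
  """Return config proto according to flag settings, or None to use default."""
  config = None
  if FLAGS.enable_xla:
    # Turn on XLA auto-jit for the whole graph.
    config = tf.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_2)
    # PinToHostOptimizer is disabled when XLA is on because it can cause OOM
    # and performance regressions (per the commit comment).
    config.graph_options.rewrite_options.pin_to_host_optimization = (
        rewriter_config_pb2.RewriterConfig.OFF)
  return config


def main(_):
  # Illustrative only: show whether XLA auto-jit was requested.
  config = get_config_proto()
  print('XLA auto-jit enabled:', FLAGS.enable_xla, 'config:', config)


if __name__ == '__main__':
  app.run(main)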
@@ -86,11 +86,20 @@ def run(flags_obj):
  Raises:
    ValueError: If fp16 is passed as it is not currently supported.
+  Returns:
+    Dictionary of training and eval stats.
  """
+  config = keras_common.get_config_proto()
  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
  # Eager is default in tf 2.0 and should not be toggled
-  if flags_obj.enable_eager and not keras_common.is_v2_0():
-    tf.compat.v1.enable_eager_execution()
+  if not keras_common.is_v2_0():
+    if flags_obj.enable_eager:
+      tf.compat.v1.enable_eager_execution(config=config)
+    else:
+      sess = tf.Session(config=config)
+      tf.keras.backend.set_session(sess)
+  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
  dtype = flags_core.get_tf_dtype(flags_obj)
  if dtype == 'fp16':
......
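Both modified run() functions apply the config the same way: build it once, then either pass it to eager execution or install a configured session as the Keras backend session. A standalone sketch of that wiring under TF 1.x follows; is_v2_0() is stubbed here as an assumption, whereas in the commit it comes from keras_common.

import tensorflow as tf


def is_v2_0():
  # Stand-in for keras_common.is_v2_0(): checks the installed TF version.
  return tf.__version__.startswith('2.')


def apply_config(config, enable_eager):
  """Apply a ConfigProto the way the modified run() functions do."""
  if not is_v2_0():
    if enable_eager:
      # Eager execution picks up the XLA/grappler settings from the config.
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      # Graph mode: hand a configured session to the Keras backend so that
      # model.fit() runs with XLA auto-jit enabled.
      sess = tf.Session(config=config)
      tf.keras.backend.set_session(sess)
  # In TF 2.0 the config is not applied yet (see the TODO in the diff).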