Commit 4571d3fa authored by Haoyu Zhang, committed by Toby Boyd

Add flag to enable XLA in Keras models (#6240)

* Add flag to enable XLA in Keras models

* Fix lint errors (some of them are old errors)
parent 733535dd
@@ -53,6 +53,7 @@ def learning_rate_schedule(current_epoch,
   Returns:
     Adjusted learning rate.
   """
+  del current_batch, batches_per_epoch  # not used
   initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
   learning_rate = initial_learning_rate
   for mult, start_epoch in LR_SCHEDULE:
@@ -97,10 +98,16 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
+  config = keras_common.get_config_proto()
   # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
   # Eager is default in tf 2.0 and should not be toggled
-  if flags_obj.enable_eager and not keras_common.is_v2_0():
-    tf.compat.v1.enable_eager_execution()
+  if not keras_common.is_v2_0():
+    if flags_obj.enable_eager:
+      tf.compat.v1.enable_eager_execution(config=config)
+    else:
+      sess = tf.Session(config=config)
+      tf.keras.backend.set_session(sess)
+  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
......
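The TODO(haoyuzhang) note above points out that this ConfigProto path is not yet effective with TF 2.0. As a rough sketch only (not part of this commit, and assuming a TF 2.x install), the corresponding auto-jit switch in the TF 2.x config API looks roughly like this:

import tensorflow as tf

# Sketch: in TF 2.x the graph-level XLA auto-jit setting is toggled through
# tf.config rather than a ConfigProto handed to a Session.
tf.config.optimizer.set_jit(True)

@tf.function
def square(x):
  return tf.matmul(x, x)

print(square(tf.ones([2, 2])))  # eligible ops may now be clustered for XLA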
@@ -25,6 +25,8 @@ import numpy as np
 # pylint: disable=g-bad-import-order
 from absl import flags
 import tensorflow as tf
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
                                                    gradient_descent_v2)
@@ -49,6 +51,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
     Args:
       batch_size: Total batch size.
+      log_steps: Interval of time history logs.
     """
     self.batch_size = batch_size
@@ -126,6 +129,20 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
                      'change learning rate to %s.', self.epochs, batch, lr)
+
+def get_config_proto():
+  """Return config proto according to flag settings, or None to use default."""
+  config = None
+  if FLAGS.enable_xla:
+    config = tf.ConfigProto()
+    config.graph_options.optimizer_options.global_jit_level = (
+        tf.OptimizerOptions.ON_2)
+    # Disable PinToHostOptimizer in grappler when enabling XLA because it causes
+    # OOM and performance regression.
+    config.graph_options.rewrite_options.pin_to_host_optimization = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+  return config
+
 def get_optimizer():
   """Returns optimizer to use."""
   # The learning_rate is overwritten at the beginning of each step by callback.
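For context, here is a minimal sketch of how the proto returned by get_config_proto() is meant to be consumed under TF 1.x graph mode when --enable_xla is set; it mirrors the code above and is illustrative only, not part of the diff:

import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

# Same proto that get_config_proto() builds when FLAGS.enable_xla is True.
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
    tf.OptimizerOptions.ON_2)
config.graph_options.rewrite_options.pin_to_host_optimization = (
    rewriter_config_pb2.RewriterConfig.OFF)

# Handing the proto to the Keras backend session (or to
# enable_eager_execution) is what actually turns on XLA auto-jit.
sess = tf.Session(config=config)
tf.keras.backend.set_session(sess)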
@@ -189,8 +206,13 @@ def build_stats(history, eval_output, time_callback):
 def define_keras_flags():
   """Define flags for Keras models."""
   flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
   flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
+  flags.DEFINE_boolean(
+      name='enable_xla', default=False,
+      help='Whether to enable XLA auto jit compilation. This is still an '
+           'experimental feature, and is not yet effective with TF 2.0.')
   flags.DEFINE_integer(
       name='train_steps', default=None,
       help='The number of steps to run for training. If it is larger than '
......
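A small usage sketch for the new flag follows. It re-registers the flag so the snippet is self-contained; in the model scripts define_keras_flags() does the registration and absl parses the real command line:

from absl import flags

flags.DEFINE_boolean(
    name='enable_xla', default=False,
    help='Whether to enable XLA auto jit compilation.')

FLAGS = flags.FLAGS
FLAGS(['trainer', '--enable_xla'])  # stands in for passing --enable_xla on the CLI
print(FLAGS.enable_xla)  # True; get_config_proto() will now build an XLA config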
@@ -86,11 +86,20 @@ def run(flags_obj):
   Raises:
     ValueError: If fp16 is passed as it is not currently supported.
   Returns:
     Dictionary of training and eval stats.
   """
+  config = keras_common.get_config_proto()
   # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
   # Eager is default in tf 2.0 and should not be toggled
-  if flags_obj.enable_eager and not keras_common.is_v2_0():
-    tf.compat.v1.enable_eager_execution()
+  if not keras_common.is_v2_0():
+    if flags_obj.enable_eager:
+      tf.compat.v1.enable_eager_execution(config=config)
+    else:
+      sess = tf.Session(config=config)
+      tf.keras.backend.set_session(sess)
+  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
......