Unverified Commit dba24007 authored by Haoyu Zhang, committed by GitHub

Add config to enable XLA in TF 2.0 (#6406)

parent 04792078
@@ -98,16 +98,17 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
-  config = keras_common.get_config_proto()
   # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
   # Eager is default in tf 2.0 and should not be toggled
-  if not keras_common.is_v2_0():
+  if keras_common.is_v2_0():
+    keras_common.set_config_v2()
+  else:
+    config = keras_common.get_config_proto_v1()
     if flags_obj.enable_eager:
       tf.compat.v1.enable_eager_execution(config=config)
     else:
       sess = tf.Session(config=config)
       tf.keras.backend.set_session(sess)
-  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
...
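Note: the net effect of this hunk is that run() now branches on the TF version before touching any session machinery. Below is a minimal sketch of that dispatch, assuming a hypothetical standalone helper (configure_tf) in place of the keras_common/flags_obj plumbing; only the tf.* calls are taken from the diff above.

import tensorflow as tf

def configure_tf(is_v2_0, enable_eager, config=None):
  # Sketch of the new dispatch in run(); `config` is a tf.compat.v1.ConfigProto
  # built elsewhere (get_config_proto_v1 in this change).
  if is_v2_0:
    # TF 2.0: eager execution is the default, and configuration goes through
    # the tf.config API instead of a ConfigProto (see set_config_v2 below).
    pass
  else:
    if enable_eager:
      # TF 1.x eager mode takes the ConfigProto at enable time.
      tf.compat.v1.enable_eager_execution(config=config)
    else:
      # TF 1.x graph mode: build a session and register it with Keras.
      sess = tf.compat.v1.Session(config=config)
      tf.keras.backend.set_session(sess)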
@@ -129,14 +129,14 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
         'change learning rate to %s.', self.epochs, batch, lr)


-def get_config_proto():
+def get_config_proto_v1():
   """Return config proto according to flag settings, or None to use default."""
   config = None
   if FLAGS.enable_xla:
     # TODO(haoyuzhang): Remove this monkey patch when XLA OOM issue is fixed.
     _monkey_patch_org_assert_broadcastable()
-    config = tf.ConfigProto()
+    config = tf.compat.v1.ConfigProto()
     config.graph_options.optimizer_options.global_jit_level = (
         tf.OptimizerOptions.ON_2)
     # Disable PinToHostOptimizer in grappler when enabling XLA because it causes
@@ -146,6 +146,20 @@ def get_config_proto():
   return config


+def set_config_v2():
+  """Config eager context according to flag values using TF 2.0 API."""
+  if FLAGS.enable_xla:
+    # TODO(haoyuzhang): Remove this monkey patch when XLA OOM issue is fixed.
+    _monkey_patch_org_assert_broadcastable()
+    tf.config.optimizer.set_jit(True)
+    # Disable PinToHostOptimizer in grappler when enabling XLA because it
+    # causes OOM and performance regression.
+    tf.config.optimizer.set_experimental_options(
+        {"pin_to_host_optimization": False}
+    )
+
+
 def get_optimizer():
   """Returns optimizer to use."""
   # The learning_rate is overwritten at the beginning of each step by callback.
...
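To make the new TF 2.0 path concrete, here is a short usage sketch of the same two public API calls set_config_v2() makes, with the monkey patch and FLAGS plumbing omitted; the get_* readback calls at the end are an assumption added for illustration, not part of this change.

import tensorflow as tf

# Turn on XLA auto-clustering globally (what --enable_xla enables here).
tf.config.optimizer.set_jit(True)

# Disable grappler's PinToHostOptimizer, which the comment above notes causes
# OOM and performance regression when combined with XLA.
tf.config.optimizer.set_experimental_options(
    {"pin_to_host_optimization": False})

# Read the settings back to confirm they took effect.
print(tf.config.optimizer.get_jit())
print(tf.config.optimizer.get_experimental_options())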
@@ -90,16 +90,17 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
-  config = keras_common.get_config_proto()
   # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
   # Eager is default in tf 2.0 and should not be toggled
-  if not keras_common.is_v2_0():
+  if keras_common.is_v2_0():
+    keras_common.set_config_v2()
+  else:
+    config = keras_common.get_config_proto_v1()
     if flags_obj.enable_eager:
       tf.compat.v1.enable_eager_execution(config=config)
     else:
       sess = tf.Session(config=config)
       tf.keras.backend.set_session(sess)
-  # TODO(haoyuzhang): Set config properly in TF2.0 when the config API is ready.

   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'float16':
...