Unverified Commit 1d16f473 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Remove assert_broadcastable monkey patch (#6901)

parent 4c1d95cc
......@@ -149,9 +149,6 @@ def get_config_proto_v1():
"""Return config proto according to flag settings, or None to use default."""
config = None
if FLAGS.enable_xla:
# TODO(haoyuzhang): Remove this monkey patch when XLA OOM issue is fixed.
_monkey_patch_org_assert_broadcastable()
config = tf.compat.v1.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
tf.OptimizerOptions.ON_2)
......@@ -165,9 +162,6 @@ def get_config_proto_v1():
def set_config_v2():
"""Config eager context according to flag values using TF 2.0 API."""
if FLAGS.enable_xla:
# TODO(haoyuzhang): Remove this monkey patch when XLA OOM issue is fixed.
_monkey_patch_org_assert_broadcastable()
tf.config.optimizer.set_jit(True)
# Disable PinToHostOptimizer in grappler when enabling XLA because it
# causes OOM and performance regression.
......@@ -402,29 +396,6 @@ def set_cudnn_batchnorm_mode():
os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)
def _monkey_patch_org_assert_broadcastable():
"""Monkey-patch `assert_broadcast` op to avoid OOM when enabling XLA."""
def no_op_assert_broadcastable(weights, values):
del weights, values
tf.compat.v1.logging.info(
'Using monkey-patched version of assert_broadcastable op, which always '
'returns an no_op. It should be removed after XLA OOM issue is fixed.')
return tf.constant([], dtype=tf.float32)
from tensorflow.python.ops import weights_broadcast_ops # pylint: disable=g-import-not-at-top
if not hasattr(weights_broadcast_ops, 'org_assert_broadcastable'):
weights_broadcast_ops.org_assert_broadcastable = (
weights_broadcast_ops.assert_broadcastable)
weights_broadcast_ops.assert_broadcastable = no_op_assert_broadcastable
def _undo_monkey_patch_org_assert_broadcastable():
from tensorflow.python.ops import weights_broadcast_ops # pylint: disable=g-import-not-at-top
if hasattr(weights_broadcast_ops, 'org_assert_broadcastable'):
weights_broadcast_ops.assert_broadcastable = (
weights_broadcast_ops.org_assert_broadcastable)
# TODO(haoyuzhang): remove this monkey patch when the "prefetch with slack"
# feature is available in tf.data.
def _monkey_patch_org_create_device_dataset():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment