".github/git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "ca1dc1e7d16958893aa4ef3e005ad419e55a4b71"
Unverified Commit 269581dc authored by Toby Boyd, committed by GitHub

Add XLA to transformer (#7048)



* set default steps to 300K.

* Log flags to perfzero.

* Add XLA support to transformer

- Moved config logic to keras_utils
- Added enable_xla flag to _performance flags
- Did not refactor the enable_xla flag out of Keras ResNet, because the
  Estimator-based Keras code still relies on reading FLAGS directly;
  that refactor is needed but left for another time.

* fix g3 lint complaint.

* Refactor set config into keras_utils.

* Move flags out of main.

* pipe through enable_xla

* Update official/transformer/v2/misc.py
Co-Authored-By: Reed <reedwm@google.com>
parent 1e527fb5
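
At a high level, this change registers a new --enable_xla flag via flags_core.define_performance and routes session/eager setup into a shared keras_utils.set_session_config helper (both appear in the diffs below). A minimal, self-contained sketch of what the TF 2.x path boils down to; configure_xla is an illustrative name, not code from this commit, and this assumes a TF 2.x install:

import tensorflow as tf

def configure_xla(enable_xla=True):
  """Illustrative only: enable XLA auto-clustering via the TF 2.x API."""
  if enable_xla:
    # Turn on JIT compilation (auto clustering) for supported ops.
    tf.config.optimizer.set_jit(True)
    # PinToHostOptimizer is known to cause OOMs and perf regressions with XLA,
    # so this commit disables it whenever XLA is on.
    tf.config.optimizer.set_experimental_options(
        {'pin_to_host_optimization': False})

configure_xla(True)
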
@@ -348,11 +348,14 @@ def imagenet_model_fn(features, labels, mode, params):
   )


-def define_imagenet_flags(dynamic_loss_scale=False, fp16_implementation=False):
+def define_imagenet_flags(dynamic_loss_scale=False,
+                          fp16_implementation=False,
+                          enable_xla=False):
   resnet_run_loop.define_resnet_flags(
       resnet_size_choices=['18', '34', '50', '101', '152', '200'],
       dynamic_loss_scale=dynamic_loss_scale,
-      fp16_implementation=fp16_implementation)
+      fp16_implementation=fp16_implementation,
+      enable_xla=enable_xla)
   flags.adopt_module_key_flags(resnet_run_loop)
   flags_core.set_defaults(train_epochs=90)
......
@@ -28,6 +28,7 @@ from official.resnet.keras import resnet_cifar_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils


 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
@@ -98,17 +99,7 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
-  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
-  # Eager is default in tf 2.0 and should not be toggled
-  if keras_common.is_v2_0():
-    keras_common.set_config_v2()
-  else:
-    config = keras_common.get_config_proto_v1()
-    if flags_obj.enable_eager:
-      tf.compat.v1.enable_eager_execution(config=config)
-    else:
-      sess = tf.Session(config=config)
-      tf.keras.backend.set_session(sess)
+  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager)

   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
......
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 """Common util functions and classes used by both keras cifar and imagenet."""
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -29,7 +28,6 @@ import tensorflow as tf
 from official.utils.misc import keras_utils
 # pylint: disable=ungrouped-imports
-from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
                                                   gradient_descent_v2)
@@ -145,48 +143,6 @@ class PiecewiseConstantDecayWithWarmup(
     }


-def get_config_proto_v1():
-  """Return config proto according to flag settings, or None to use default."""
-  config = None
-  if FLAGS.enable_xla:
-    config = tf.compat.v1.ConfigProto()
-    config.graph_options.optimizer_options.global_jit_level = (
-        tf.OptimizerOptions.ON_2)
-    # Disable PinToHostOptimizer in grappler when enabling XLA because it causes
-    # OOM and performance regression.
-    config.graph_options.rewrite_options.pin_to_host_optimization = (
-        rewriter_config_pb2.RewriterConfig.OFF)
-  # TODO(b/76028325): Remove when generic layout optimizer will be ready.
-  if not FLAGS.enable_grappler_layout_optimizer:
-    if config is None:
-      config = tf.compat.v1.ConfigProto()
-    # Disable LayoutOptimizer in grappler, because it might de-optimize fp16
-    # graphs, and force NCHW data format in all convolutions and batch
-    # normalizations.
-    config.graph_options.rewrite_options.layout_optimizer = (
-        rewriter_config_pb2.RewriterConfig.OFF)
-  return config
-
-
-def set_config_v2():
-  """Config eager context according to flag values using TF 2.0 API."""
-  if FLAGS.enable_xla:
-    tf.config.optimizer.set_jit(True)
-    # Disable PinToHostOptimizer in grappler when enabling XLA because it
-    # causes OOM and performance regression.
-    tf.config.optimizer.set_experimental_options(
-        {'pin_to_host_optimization': False}
-    )
-  # TODO(b/76028325): Remove when generic layout optimizer will be ready.
-  if not FLAGS.enable_grappler_layout_optimizer:
-    # Disable LayoutOptimizer in grappler, because it might de-optimize fp16
-    # graphs, and force NCHW data format in all convolutions and batch
-    # normalizations.
-    tf.config.optimizer.set_experimental_options(
-        {'layout_optimizer': False}
-    )


 def set_gpu_thread_mode_and_count(flags_obj):
   """Set GPU thread mode and count, and adjust dataset threads count."""
   cpu_count = multiprocessing.cpu_count()
@@ -306,10 +262,6 @@ def define_keras_flags():
                        help='Report metrics during training and evaluation.')
   flags.DEFINE_boolean(name='use_tensor_lr', default=False,
                        help='Use learning rate tensor instead of a callback.')
-  flags.DEFINE_boolean(
-      name='enable_xla', default=False,
-      help='Whether to enable XLA auto jit compilation. This is still an '
-           'experimental feature, and is not yet effective with TF 2.0.')
   flags.DEFINE_boolean(
       name='enable_tensorboard', default=False,
       help='Whether to enable Tensorboard callback.')
@@ -404,14 +356,6 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
   return input_fn


-def is_v2_0():
-  """Returns true if using tf 2.0."""
-  if hasattr(tf, 'contrib'):
-    return False
-  else:
-    return True


 def data_delay_prefetch():
   """Use unstable code for perf tuning purposes."""
   if not FLAGS.use_synthetic_data:
@@ -419,8 +363,11 @@ def data_delay_prefetch():

 def set_cudnn_batchnorm_mode():
-  """Set CuDNN batchnorm mode for better performance. Note that the spatial
-  persistent mode may lead to accuracy losses for certain models."""
+  """Set CuDNN batchnorm mode for better performance.
+
+  Note: Spatial Persistent mode may lead to accuracy losses for certain
+  models.
+  """
   if FLAGS.batchnorm_spatial_persistent:
     os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
   else:
......
@@ -21,9 +21,7 @@ import time
 from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order

-from official.resnet import imagenet_main
 from official.resnet.keras import keras_benchmark
-from official.resnet.keras import keras_common
 from official.resnet.keras import keras_imagenet_main

 MIN_TOP_1_ACCURACY = 0.76
@@ -46,10 +44,7 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
       named arguments before updating the constructor.
     """
-    flag_methods = [
-        keras_common.define_keras_flags,
-        lambda: imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
-    ]
+    flag_methods = [keras_imagenet_main.define_imagenet_keras_flags]

     self.data_dir = os.path.join(root_data_dir, 'imagenet')
     super(Resnet50KerasAccuracy, self).__init__(
@@ -206,10 +201,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
   """Resnet50 benchmarks."""

   def __init__(self, output_dir=None, default_flags=None):
-    flag_methods = [
-        keras_common.define_keras_flags,
-        lambda: imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
-    ]
+    flag_methods = [keras_imagenet_main.define_imagenet_keras_flags]

     super(Resnet50KerasBenchmarkBase, self).__init__(
         output_dir=output_dir,
@@ -1153,10 +1145,8 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
   """Trivial model with real data benchmark tests."""

   def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
-    flag_methods = [
-        keras_common.define_keras_flags,
-        lambda: imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
-    ]
+    flag_methods = [keras_imagenet_main.define_imagenet_keras_flags]
+
     def_flags = {}
     def_flags['use_trivial_model'] = True
     def_flags['skip_eval'] = True
......
@@ -29,6 +29,7 @@ from official.resnet.keras import trivial_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
@@ -92,17 +93,11 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
-  # TODO(tobyboyd): Remove eager flag when tf 1.0 testing ends.
-  # Eager is default in tf 2.0 and should not be toggled
-  if keras_common.is_v2_0():
-    keras_common.set_config_v2()
-  else:
-    config = keras_common.get_config_proto_v1()
-    if flags_obj.enable_eager:
-      tf.compat.v1.enable_eager_execution(config=config)
-    else:
-      sess = tf.Session(config=config)
-      tf.keras.backend.set_session(sess)
+  keras_utils.set_session_config(
+      enable_eager=flags_obj.enable_eager,
+      enable_xla=flags_obj.enable_xla,
+      enable_grappler_layout_optimizer=
+      flags_obj.enable_grappler_layout_optimizer)

   # Execute flag override logic for better model performance
   if flags_obj.tf_gpu_thread_mode:
@@ -253,6 +248,11 @@ def run(flags_obj):
   return stats


+def define_imagenet_keras_flags():
+  imagenet_main.define_imagenet_flags(dynamic_loss_scale=True, enable_xla=True)
+  keras_common.define_keras_flags()
+
+
 def main(_):
   model_helpers.apply_clean(flags.FLAGS)
   with logger.benchmark_context(flags.FLAGS):
@@ -261,6 +261,5 @@ def main(_):

 if __name__ == '__main__':
   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
-  imagenet_main.define_imagenet_flags(dynamic_loss_scale=True)
-  keras_common.define_keras_flags()
+  define_imagenet_keras_flags()
   absl_app.run(main)
@@ -460,8 +460,9 @@ def resnet_model_fn(features, labels, mode, model_class,
     fp16_implementation = getattr(flags.FLAGS, 'fp16_implementation', None)
     if fp16_implementation == 'graph_rewrite':
-      optimizer = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
-          optimizer, loss_scale=loss_scale)
+      optimizer = (
+          tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
+              optimizer, loss_scale=loss_scale))

   def _dense_grad_filter(gvs):
     """Only apply gradient updates to the final layer.
@@ -722,7 +723,7 @@ def resnet_main(

 def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
-                        fp16_implementation=False):
+                        fp16_implementation=False, enable_xla=False):
   """Add flags and validators for ResNet."""
   flags_core.define_base()
   flags_core.define_performance(num_parallel_calls=False,
@@ -731,7 +732,8 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
                                 dynamic_loss_scale=dynamic_loss_scale,
                                 fp16_implementation=fp16_implementation,
                                 loss_scale=True,
-                                tf_data_experimental_slack=True)
+                                tf_data_experimental_slack=True,
+                                enable_xla=enable_xla)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)
......
@@ -25,6 +25,7 @@ from absl.testing import flagsaver
 import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.transformer import transformer_main as transformer_main
+from official.utils.flags import core as flags_core
 from official.utils.logs import hooks

 TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
@@ -100,10 +101,12 @@ class EstimatorBenchmark(tf.test.Benchmark):
       exp_per_sec = sum(exp_per_second_list) / (len(exp_per_second_list))
       metrics.append({'name': 'exp_per_second',
                       'value': exp_per_sec})
-    self.report_benchmark(
-        iters=eval_results['global_step'],
-        wall_time=wall_time_sec,
-        metrics=metrics)
+
+    flags_str = flags_core.get_nondefault_flags_as_str()
+    self.report_benchmark(iters=eval_results['global_step'],
+                          wall_time=wall_time_sec,
+                          metrics=metrics,
+                          extras={'flags': flags_str})


 class TransformerBigEstimatorAccuracy(EstimatorBenchmark):
......
@@ -69,8 +69,21 @@ def define_transformer_flags():
       synthetic_data=True,
       max_train_steps=False,
       dtype=False,
-      all_reduce_alg=True
+      all_reduce_alg=True,
+      enable_xla=True
   )

+  # Additional performance flags
+  # TODO(b/76028325): Remove when generic layout optimizer is ready.
+  flags.DEFINE_boolean(
+      name='enable_grappler_layout_optimizer',
+      default=True,
+      help='Enable Grappler layout optimizer. Currently Grappler can '
+           'de-optimize fp16 graphs by forcing NCHW layout for all '
+           'convolutions and batch normalizations, and this flag allows to '
+           'disable it.'
+  )
+
   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)
......
@@ -39,6 +39,7 @@ from official.transformer.v2 import transformer
 from official.transformer.v2 import translate
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
+from official.utils.misc import keras_utils
 from official.utils.misc import distribution_utils
@@ -121,6 +122,12 @@ class TransformerTask(object):
   def train(self):
     """Trains the model."""
     params, flags_obj, is_train = self.params, self.flags_obj, True
+    # Sets config options.
+    keras_utils.set_session_config(
+        enable_xla=flags_obj.enable_xla,
+        enable_grappler_layout_optimizer=
+        flags_obj.enable_grappler_layout_optimizer)
+
     _ensure_dir(flags_obj.model_dir)
     if self.distribution_strategy:
       with self.distribution_strategy.scope():
......
@@ -61,7 +61,7 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                        datasets_num_parallel_batches=False,
                        dynamic_loss_scale=False, fp16_implementation=False,
                        loss_scale=False,
-                       tf_data_experimental_slack=False):
+                       tf_data_experimental_slack=False, enable_xla=False):
   """Register flags for specifying performance tuning arguments.

   Args:
@@ -86,6 +86,7 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
       training. Can only be turned on if dtype is also True.
     tf_data_experimental_slack: Determines whether to enable tf.data's
       `experimental_slack` option.
+    enable_xla: Determines if XLA (auto clustering) is turned on.

   Returns:
     A list of flags for core.py to marks as key flags.
@@ -270,4 +271,9 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
         "Whether to enable tf.data's `experimental_slack` option.")
     )

+  if enable_xla:
+    flags.DEFINE_boolean(
+        name="enable_xla", default=False,
+        help="Whether to enable XLA auto jit compilation")
+
   return key_flags
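
Because the enable_xla flag is registered only when define_performance is called with enable_xla=True, each model opts in explicitly and then reads FLAGS.enable_xla at runtime. A rough, self-contained sketch of that opt-in pattern, using absl directly as a stand-in for flags_core.define_performance (define_performance_flags and the argv parsing below are illustrative, not code from this commit):

from absl import flags

FLAGS = flags.FLAGS

def define_performance_flags(enable_xla=False):
  # Illustrative stand-in for flags_core.define_performance: the XLA flag is
  # registered only for models that explicitly opt in.
  key_flags = []
  if enable_xla:
    flags.DEFINE_boolean(
        name='enable_xla', default=False,
        help='Whether to enable XLA auto jit compilation')
    key_flags.append('enable_xla')
  return key_flags

define_performance_flags(enable_xla=True)
FLAGS(['prog', '--enable_xla'])  # parse argv; FLAGS.enable_xla is now True
print(FLAGS.enable_xla)
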
@@ -21,6 +21,7 @@ from __future__ import print_function
 import time

 import tensorflow as tf
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.eager import profiler
@@ -128,3 +129,73 @@ class ProfilerCallback(tf.keras.callbacks.Callback):
     tf.compat.v1.logging.info(
         'Profiler saved profiles for steps between %s and %s to %s',
         self.start_step, self.stop_step, self.log_dir)
+
+
+def set_session_config(enable_eager=False,
+                       enable_xla=False,
+                       enable_grappler_layout_optimizer=True):
+  """Sets the session config."""
+  if is_v2_0():
+    set_config_v2(
+        enable_xla=enable_xla,
+        enable_grappler_layout_optimizer=enable_grappler_layout_optimizer)
+  else:
+    config = get_config_proto_v1(
+        enable_xla=enable_xla,
+        enable_grappler_layout_optimizer=enable_grappler_layout_optimizer)
+    if enable_eager:
+      tf.compat.v1.enable_eager_execution(config=config)
+    else:
+      sess = tf.Session(config=config)
+      tf.keras.backend.set_session(sess)
+
+
+def get_config_proto_v1(enable_xla=False,
+                        enable_grappler_layout_optimizer=True):
+  """Return config proto according to flag settings, or None to use default."""
+  config = None
+  if enable_xla:
+    config = tf.compat.v1.ConfigProto()
+    config.graph_options.optimizer_options.global_jit_level = (
+        tf.OptimizerOptions.ON_2)
+    # Disable PinToHostOptimizer in grappler when enabling XLA because it causes
+    # OOM and performance regression.
+    config.graph_options.rewrite_options.pin_to_host_optimization = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+  # TODO(b/76028325): Remove when generic layout optimizer will be ready.
+  if not enable_grappler_layout_optimizer:
+    if config is None:
+      config = tf.compat.v1.ConfigProto()
+    # Disable LayoutOptimizer in grappler, because it might de-optimize fp16
+    # graphs, and force NCHW data format in all convolutions and batch
+    # normalizations.
+    config.graph_options.rewrite_options.layout_optimizer = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+  return config
+
+
+def set_config_v2(enable_xla=False,
+                  enable_grappler_layout_optimizer=False):
+  """Config eager context according to flag values using TF 2.0 API."""
+  if enable_xla:
+    tf.config.optimizer.set_jit(True)
+    # Disable PinToHostOptimizer in grappler when enabling XLA because it
+    # causes OOM and performance regression.
+    tf.config.optimizer.set_experimental_options(
+        {'pin_to_host_optimization': False}
+    )
+  # TODO(b/76028325): Remove when generic layout optimizer will be ready.
+  if not enable_grappler_layout_optimizer:
+    # Disable LayoutOptimizer in grappler, because it might de-optimize fp16
+    # graphs, and force NCHW data format in all convolutions and batch
+    # normalizations.
+    tf.config.optimizer.set_experimental_options(
+        {'layout_optimizer': False}
+    )
+
+
+def is_v2_0():
+  """Returns true if using tf 2.0."""
+  if hasattr(tf, 'contrib'):
+    return False
+  else:
+    return True
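
A short usage sketch of the new helper, mirroring how the updated run() functions call it; the hard-coded values stand in for flags_obj attributes, and this assumes the official/utils/misc/keras_utils.py module from this commit is importable:

from official.utils.misc import keras_utils

# Under TF 2.x this calls set_config_v2 (tf.config.optimizer.set_jit);
# under TF 1.x it builds a ConfigProto and installs a Keras session.
keras_utils.set_session_config(
    enable_eager=False,
    enable_xla=True,
    enable_grappler_layout_optimizer=True)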