Commit fe682959 authored by Lucy Fox's avatar Lucy Fox Committed by A. Unique TensorFlower
Browse files

Enable MLIR bridge by default for NHNet and Bert2Bert training.

Testing: Trained Bert2Bert model with MLIR bridge enabled and then NHNet based on Bert2Bert checkpoint, also with MLIR bridge enabled. Exported this to a saved model, ran inference, and confirmed model performance, accuracy, and prediction parity with the model trained without the MLIR bridge enabled.
PiperOrigin-RevId: 326734770
parent 6f91155b
......@@ -85,6 +85,10 @@ def define_flags():
default=None,
help=("a YAML/JSON string or a YAML file which specifies additional "
"overrides over the default parameters"))
# Enables MLIR-based TF/XLA bridge. This is part of a soft rollout and will
# eventually be the Google-wide default.
flags.DEFINE_bool("enable_mlir_bridge", True,
"Use MLIR TF/XLA bridge (experimental) -- NHNet.")
# pylint: disable=protected-access
......@@ -178,6 +182,9 @@ def train(params, strategy, dataset=None):
def run():
"""Runs NHNet using Keras APIs."""
if FLAGS.enable_mlir_bridge:
tf.config.experimental.enable_mlir_bridge()
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
if strategy:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment