"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "c98669436742f79f8ffaa796ff94ee2c7f17201a"
Commit fe682959 authored by Lucy Fox's avatar Lucy Fox Committed by A. Unique TensorFlower
Browse files

Enable MLIR bridge by default for NHNet and Bert2Bert training.

Testing: Trained Bert2Bert model with MLIR bridge enabled and then NHNet based on Bert2Bert checkpoint, also with MLIR bridge enabled. Exported this to a saved model, ran inference, and confirmed model performance, accuracy, and prediction parity with the model trained without the MLIR bridge enabled.
PiperOrigin-RevId: 326734770
parent 6f91155b
...@@ -85,6 +85,10 @@ def define_flags(): ...@@ -85,6 +85,10 @@ def define_flags():
default=None, default=None,
help=("a YAML/JSON string or a YAML file which specifies additional " help=("a YAML/JSON string or a YAML file which specifies additional "
"overrides over the default parameters")) "overrides over the default parameters"))
# Enables MLIR-based TF/XLA bridge. This is part of a soft rollout and will
# eventually be the Google-wide default.
flags.DEFINE_bool("enable_mlir_bridge", True,
"Use MLIR TF/XLA bridge (experimental) -- NHNet.")
# pylint: disable=protected-access # pylint: disable=protected-access
...@@ -178,6 +182,9 @@ def train(params, strategy, dataset=None): ...@@ -178,6 +182,9 @@ def train(params, strategy, dataset=None):
def run(): def run():
"""Runs NHNet using Keras APIs.""" """Runs NHNet using Keras APIs."""
if FLAGS.enable_mlir_bridge:
tf.config.experimental.enable_mlir_bridge()
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu) distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
if strategy: if strategy:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment