"torchvision/models/vscode:/vscode.git/clone" did not exist on "f16b67234260f0a32f6106313ad42e102ada6fa0"
Commit 0886b384 authored by Lucy Fox's avatar Lucy Fox Committed by A. Unique TensorFlower
Browse files

Test Transformer training and evaluation with MLIR TF to XLA bridge enabled.

PiperOrigin-RevId: 308666667
parent 5ab76b51
......@@ -109,6 +109,10 @@ def define_transformer_flags():
flags.DEFINE_boolean(
name='enable_metrics_in_training', default=False,
help='Whether to enable metrics during training.')
flags.DEFINE_boolean(
name='enable_mlir_bridge',
default=False,
help='Whether to enable the TF to XLA bridge.')
flags.DEFINE_string(
name='profile_steps', default=None,
help='Save profiling data to model dir at given range of steps. The '
......
......@@ -470,6 +470,8 @@ def _ensure_dir(log_dir):
def main(_):
flags_obj = flags.FLAGS
if flags_obj.enable_mlir_bridge:
tf.config.experimental.enable_mlir_bridge()
task = TransformerTask(flags_obj)
# Execute flag override logic for better model performance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment