Commit 51cb03b0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Deprecate the graph rewrite path for fp16. This is no longer a TF2 api and there is no usage.

PiperOrigin-RevId: 410629444
parent 743c28aa
......@@ -14,14 +14,19 @@
"""Functions and classes related to training performance."""
from absl import logging
import tensorflow as tf
def configure_optimizer(optimizer,
use_float16=False,
use_graph_rewrite=False,
loss_scale=None):
loss_scale=None,
use_graph_rewrite=None):
"""Configures optimizer object with performance options."""
if use_graph_rewrite is not None:
logging.warning('`use_graph_rewrite` is deprecated inside '
'`configure_optimizer`. Please remove the usage.')
del use_graph_rewrite
if use_float16:
if loss_scale in (None, 'dynamic'):
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
......@@ -29,13 +34,6 @@ def configure_optimizer(optimizer,
# loss_scale is a number. We interpret that as a fixed loss scale.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
optimizer, dynamic=False, initial_scale=loss_scale)
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.keras.mixed_precision and enable_mixed_precision_graph_rewrite do not
# double up.
optimizer = (
tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
optimizer))
return optimizer
......
......@@ -121,9 +121,5 @@ def use_float16():
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
def use_graph_rewrite():
return flags.FLAGS.fp16_implementation == 'graph_rewrite'
def get_loss_scale():
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
......@@ -150,8 +150,7 @@ def run_bert_classifier(strategy,
FLAGS.optimizer_type)
classifier_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
use_float16=common_flags.use_float16())
return classifier_model, core_model
# tf.keras.losses objects accept optional sample_weight arguments (eg. coming
......
......@@ -125,8 +125,7 @@ def run_customized_training(strategy,
end_lr, optimizer_type)
pretrain_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
use_float16=common_flags.use_float16())
return pretrain_model, core_model
trained_model = model_training_utils.run_customized_training_loop(
......
......@@ -252,8 +252,7 @@ def train_squad(strategy,
squad_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
use_float16=common_flags.use_float16())
return squad_model, core_model
# Only when explicit_allreduce = True, post_allreduce_callbacks and
......
......@@ -440,7 +440,6 @@ class TransformerTask(object):
opt = performance.configure_optimizer(
opt,
use_float16=params["dtype"] == tf.float16,
use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
loss_scale=flags_core.get_loss_scale(
self.flags_obj, default_for_fp16="dynamic"))
......
......@@ -401,7 +401,6 @@ class YoloTask(base_task.Task):
use_float16 = runtime_config.mixed_precision_dtype == 'float16'
optimizer = performance.configure_optimizer(
optimizer,
use_graph_rewrite=False,
use_float16=use_float16,
loss_scale=runtime_config.loss_scale)
......
......@@ -72,14 +72,9 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
# Make sure iterations variable is created inside scope.
self.global_step = self.optimizer.iterations
use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
if use_graph_rewrite and not flags_obj.use_tf_function:
raise ValueError('--fp16_implementation=graph_rewrite requires '
'--use_tf_function to be true')
self.optimizer = performance.configure_optimizer(
self.optimizer,
use_float16=self.dtype == tf.float16,
use_graph_rewrite=use_graph_rewrite,
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment