Commit 51cb03b0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Deprecate the graph rewrite path for fp16. This is no longer a TF2 api and there is no usage.

PiperOrigin-RevId: 410629444
parent 743c28aa
...@@ -14,14 +14,19 @@ ...@@ -14,14 +14,19 @@
"""Functions and classes related to training performance.""" """Functions and classes related to training performance."""
from absl import logging
import tensorflow as tf import tensorflow as tf
def configure_optimizer(optimizer, def configure_optimizer(optimizer,
use_float16=False, use_float16=False,
use_graph_rewrite=False, loss_scale=None,
loss_scale=None): use_graph_rewrite=None):
"""Configures optimizer object with performance options.""" """Configures optimizer object with performance options."""
if use_graph_rewrite is not None:
logging.warning('`use_graph_rewrite` is deprecated inside '
'`configure_optimizer`. Please remove the usage.')
del use_graph_rewrite
if use_float16: if use_float16:
if loss_scale in (None, 'dynamic'): if loss_scale in (None, 'dynamic'):
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer) optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
...@@ -29,13 +34,6 @@ def configure_optimizer(optimizer, ...@@ -29,13 +34,6 @@ def configure_optimizer(optimizer,
# loss_scale is a number. We interpret that as a fixed loss scale. # loss_scale is a number. We interpret that as a fixed loss scale.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer( optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
optimizer, dynamic=False, initial_scale=loss_scale) optimizer, dynamic=False, initial_scale=loss_scale)
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.keras.mixed_precision and enable_mixed_precision_graph_rewrite do not
# double up.
optimizer = (
tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
optimizer))
return optimizer return optimizer
......
...@@ -121,9 +121,5 @@ def use_float16(): ...@@ -121,9 +121,5 @@ def use_float16():
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16 return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
def use_graph_rewrite():
return flags.FLAGS.fp16_implementation == 'graph_rewrite'
def get_loss_scale(): def get_loss_scale():
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic') return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
...@@ -150,8 +150,7 @@ def run_bert_classifier(strategy, ...@@ -150,8 +150,7 @@ def run_bert_classifier(strategy,
FLAGS.optimizer_type) FLAGS.optimizer_type)
classifier_model.optimizer = performance.configure_optimizer( classifier_model.optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=common_flags.use_float16(), use_float16=common_flags.use_float16())
use_graph_rewrite=common_flags.use_graph_rewrite())
return classifier_model, core_model return classifier_model, core_model
# tf.keras.losses objects accept optional sample_weight arguments (eg. coming # tf.keras.losses objects accept optional sample_weight arguments (eg. coming
......
...@@ -125,8 +125,7 @@ def run_customized_training(strategy, ...@@ -125,8 +125,7 @@ def run_customized_training(strategy,
end_lr, optimizer_type) end_lr, optimizer_type)
pretrain_model.optimizer = performance.configure_optimizer( pretrain_model.optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=common_flags.use_float16(), use_float16=common_flags.use_float16())
use_graph_rewrite=common_flags.use_graph_rewrite())
return pretrain_model, core_model return pretrain_model, core_model
trained_model = model_training_utils.run_customized_training_loop( trained_model = model_training_utils.run_customized_training_loop(
......
...@@ -252,8 +252,7 @@ def train_squad(strategy, ...@@ -252,8 +252,7 @@ def train_squad(strategy,
squad_model.optimizer = performance.configure_optimizer( squad_model.optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_float16=common_flags.use_float16(), use_float16=common_flags.use_float16())
use_graph_rewrite=common_flags.use_graph_rewrite())
return squad_model, core_model return squad_model, core_model
# Only when explicit_allreduce = True, post_allreduce_callbacks and # Only when explicit_allreduce = True, post_allreduce_callbacks and
......
...@@ -440,7 +440,6 @@ class TransformerTask(object): ...@@ -440,7 +440,6 @@ class TransformerTask(object):
opt = performance.configure_optimizer( opt = performance.configure_optimizer(
opt, opt,
use_float16=params["dtype"] == tf.float16, use_float16=params["dtype"] == tf.float16,
use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
loss_scale=flags_core.get_loss_scale( loss_scale=flags_core.get_loss_scale(
self.flags_obj, default_for_fp16="dynamic")) self.flags_obj, default_for_fp16="dynamic"))
......
...@@ -401,7 +401,6 @@ class YoloTask(base_task.Task): ...@@ -401,7 +401,6 @@ class YoloTask(base_task.Task):
use_float16 = runtime_config.mixed_precision_dtype == 'float16' use_float16 = runtime_config.mixed_precision_dtype == 'float16'
optimizer = performance.configure_optimizer( optimizer = performance.configure_optimizer(
optimizer, optimizer,
use_graph_rewrite=False,
use_float16=use_float16, use_float16=use_float16,
loss_scale=runtime_config.loss_scale) loss_scale=runtime_config.loss_scale)
......
...@@ -72,14 +72,9 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator): ...@@ -72,14 +72,9 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
# Make sure iterations variable is created inside scope. # Make sure iterations variable is created inside scope.
self.global_step = self.optimizer.iterations self.global_step = self.optimizer.iterations
use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
if use_graph_rewrite and not flags_obj.use_tf_function:
raise ValueError('--fp16_implementation=graph_rewrite requires '
'--use_tf_function to be true')
self.optimizer = performance.configure_optimizer( self.optimizer = performance.configure_optimizer(
self.optimizer, self.optimizer,
use_float16=self.dtype == tf.float16, use_float16=self.dtype == tf.float16,
use_graph_rewrite=use_graph_rewrite,
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128)) loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32) self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment