Commit 3eafa4c5 authored by Pankaj Kanwar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339927293
parent 1308a574
......@@ -195,14 +195,14 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
# For mixed precision, when a LossScaleOptimizer is used, the loss is
# scaled to avoid numeric underflow.
if isinstance(optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
if isinstance(optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
......
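For readers following the migration, here is a minimal, self-contained sketch (not code from this commit) of the training-step pattern that the hunk above switches to the non-experimental API; `model`, `loss_fn`, `features`, and `labels` are placeholder names.

import tensorflow as tf

# Wrap any Keras optimizer; with no extra arguments the loss scale is dynamic.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1))

def train_step(model, loss_fn, features, labels):
  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    loss = loss_fn(labels, outputs)
    # Scale the loss so small float16 gradients do not underflow to zero.
    scaled_loss = optimizer.get_scaled_loss(loss)
  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Unscale before applying so the update uses the true gradient values.
  grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(list(zip(grads, tvars)))
  return loss

With compile()/fit() and the 'mixed_float16' policy this wrapping happens automatically; the manual scaling is only needed in custom training loops such as the one above.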
......@@ -96,7 +96,10 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
},
'learning_rate': {
'type': 'constant'
}
},
'use_experimental_api': {
'type': False
},
})))
trainer = self.create_test_trainer(config)
if mixed_precision_dtype != 'float16':
......@@ -106,7 +109,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
else:
self.assertIsInstance(
trainer.optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer)
tf.keras.mixed_precision.LossScaleOptimizer)
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
......
......@@ -21,15 +21,24 @@ import tensorflow as tf
def configure_optimizer(optimizer,
use_float16=False,
use_graph_rewrite=False,
loss_scale='dynamic'):
loss_scale='dynamic',
use_experimental_api=True):
"""Configures optimizer object with performance options."""
if use_float16:
# TODO(b/171936854): Move all methods to non-experimental api.
if use_experimental_api:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=loss_scale))
elif loss_scale == 'dynamic':
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
else:
# loss_scale is a number. We interpret that as a fixed loss scale.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
optimizer, dynamic=False, initial_scale=loss_scale)
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.keras.mixed_precision and
......@@ -40,16 +49,26 @@ def configure_optimizer(optimizer,
return optimizer
def set_mixed_precision_policy(dtype, loss_scale=None):
def set_mixed_precision_policy(dtype, loss_scale=None,
use_experimental_api=True):
"""Sets mix precision policy."""
if dtype == tf.float16:
# TODO(b/171936854): Move all methods to non-experimental api.
if use_experimental_api:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_float16', loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy)
else:
tf.keras.mixed_precision.set_global_policy('mixed_float16')
elif dtype == tf.bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
tf.keras.mixed_precision.experimental.set_policy(policy)
if use_experimental_api:
tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
else:
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
elif dtype == tf.float32:
if use_experimental_api:
tf.keras.mixed_precision.experimental.set_policy('float32')
else:
tf.keras.mixed_precision.set_global_policy('float32')
else:
raise ValueError('Unexpected dtype: %s' % dtype)
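The two non-experimental code paths that the updated performance.py selects when use_experimental_api is False can also be exercised directly. A hedged sketch follows; the value 128 is used only as an example of a numeric (fixed) loss scale.

import tensorflow as tf

# Replacement for tf.keras.mixed_precision.experimental.set_policy(...).
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# loss_scale == 'dynamic': the wrapper adjusts the scale automatically.
dynamic_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1))

# A numeric loss_scale is interpreted as a fixed scale.
fixed_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1),
    dynamic=False,
    initial_scale=128)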
......@@ -122,7 +122,7 @@ def minimize_using_explicit_allreduce(tape,
in one pack.
"""
if isinstance(optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
# FP16 GPU code path
with tape:
scaled_loss = optimizer.get_scaled_loss(loss)
......
......@@ -168,7 +168,7 @@ class ImageClassificationTask(base_task.Task):
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
......@@ -176,7 +176,7 @@ class ImageClassificationTask(base_task.Task):
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
# Apply gradient clipping.
......
......@@ -99,7 +99,8 @@ def run(flags_obj):
"""
keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj),
use_experimental_api=False)
if tf.config.list_physical_devices('GPU'):
if flags_obj.tf_gpu_thread_mode:
......
......@@ -81,7 +81,8 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
self.optimizer,
use_float16=self.dtype == tf.float16,
use_graph_rewrite=use_graph_rewrite,
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128),
use_experimental_api=False)
self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
......
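Putting the driver-side changes together, a hypothetical caller that opts out of the experimental API (assuming the tf-models helpers in official/modeling/performance.py; the flag plumbing used by the ResNet scripts is omitted) would look roughly like this:

import tensorflow as tf
from official.modeling import performance

# Roughly mirrors the resnet_ctl_imagenet_main.py change: set the policy via
# the non-experimental API.
performance.set_mixed_precision_policy(tf.float16, use_experimental_api=False)

# Roughly mirrors the ResnetRunnable change: wrap the optimizer without the
# experimental API; 128 stands in for the fp16 default loss scale.
optimizer = performance.configure_optimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1),
    use_float16=True,
    loss_scale=128,
    use_experimental_api=False)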