Commit 72a31e9e authored by Pankaj Kanwar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339927293
parent a75e870b
@@ -195,14 +195,14 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
       # For mixed precision, when a LossScaleOptimizer is used, the loss is
       # scaled to avoid numeric underflow.
       if isinstance(optimizer,
-                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+                    tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)

     tvars = model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)

     if isinstance(optimizer,
-                  tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+                  tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
     optimizer.apply_gradients(list(zip(grads, tvars)))
     logs = {self.loss: loss}
...
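For readers skimming the hunk above: it is the standard custom-training-loop pattern for loss scaling, which this commit simply points at the non-experimental LossScaleOptimizer. A minimal self-contained sketch of that pattern against the new API (the toy model and loss here are illustrative, not from this repository):

import tensorflow as tf

# Illustrative model/optimizer; only the scale/unscale steps mirror the hunk.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1))

@tf.function
def train_step(features, labels):
  with tf.GradientTape() as tape:
    logits = model(features, training=True)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True))
    # Scale the loss so small float16 gradients do not underflow to zero.
    scaled_loss = optimizer.get_scaled_loss(loss)
  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Unscale before applying so the update magnitude is unchanged.
  grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(list(zip(grads, tvars)))
  return loss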
@@ -96,7 +96,10 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
             },
             'learning_rate': {
                 'type': 'constant'
-            }
+            },
+            'use_experimental_api': {
+                'type': False
+            },
         })))
     trainer = self.create_test_trainer(config)
     if mixed_precision_dtype != 'float16':
@@ -106,7 +109,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     else:
       self.assertIsInstance(
           trainer.optimizer,
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer)
+          tf.keras.mixed_precision.LossScaleOptimizer)
     metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertIn('training_loss', metrics)
...
@@ -21,15 +21,24 @@ import tensorflow as tf
 def configure_optimizer(optimizer,
                         use_float16=False,
                         use_graph_rewrite=False,
-                        loss_scale='dynamic'):
+                        loss_scale='dynamic',
+                        use_experimental_api=True):
   """Configures optimizer object with performance options."""
   if use_float16:
-    # Wraps optimizer with a LossScaleOptimizer. This is done automatically
-    # in compile() with the "mixed_float16" policy, but since we do not call
-    # compile(), we must wrap the optimizer manually.
-    optimizer = (
-        tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-            optimizer, loss_scale=loss_scale))
+    # TODO(b/171936854): Move all methods to non-experimental api.
+    if use_experimental_api:
+      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
+      # in compile() with the "mixed_float16" policy, but since we do not call
+      # compile(), we must wrap the optimizer manually.
+      optimizer = (
+          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+              optimizer, loss_scale=loss_scale))
+    elif loss_scale == 'dynamic':
+      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
+    else:
+      # loss_scale is a number. We interpret that as a fixed loss scale.
+      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
+          optimizer, dynamic=False, initial_scale=loss_scale)
   if use_graph_rewrite:
     # Note: the model dtype must be 'float32', which will ensure
     # tf.keras.mixed_precision and
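A usage note on the new branches above: in the non-experimental API, loss scaling is configured through the LossScaleOptimizer constructor rather than a loss_scale argument, so 'dynamic' maps to the constructor's default and a number maps to a fixed scale. A hedged sketch of the two resulting call shapes:

import tensorflow as tf

# loss_scale='dynamic': dynamic loss scaling is the constructor default.
dynamic_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1))

# Numeric loss_scale (e.g. 128): fixed scaling, dynamic adjustment disabled.
fixed_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1), dynamic=False, initial_scale=128)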
@@ -40,16 +49,26 @@ def configure_optimizer(optimizer,
   return optimizer


-def set_mixed_precision_policy(dtype, loss_scale=None):
+def set_mixed_precision_policy(dtype, loss_scale=None,
+                               use_experimental_api=True):
   """Sets mix precision policy."""
   if dtype == tf.float16:
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        'mixed_float16', loss_scale=loss_scale)
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    # TODO(b/171936854): Move all methods to non-experimental api.
+    if use_experimental_api:
+      policy = tf.keras.mixed_precision.experimental.Policy(
+          'mixed_float16', loss_scale=loss_scale)
+      tf.keras.mixed_precision.experimental.set_policy(policy)
+    else:
+      tf.keras.mixed_precision.set_global_policy('mixed_float16')
   elif dtype == tf.bfloat16:
-    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    if use_experimental_api:
+      tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
+    else:
+      tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
   elif dtype == tf.float32:
-    tf.keras.mixed_precision.experimental.set_policy('float32')
+    if use_experimental_api:
+      tf.keras.mixed_precision.experimental.set_policy('float32')
+    else:
+      tf.keras.mixed_precision.set_global_policy('float32')
   else:
     raise ValueError('Unexpected dtype: %s' % dtype)
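The replacement call shown here, set_global_policy, takes the policy name directly and no longer carries a loss_scale; in the non-experimental API, loss scaling lives entirely in LossScaleOptimizer (or is added automatically by compile() under 'mixed_float16'). A small sketch of the new-style call and how to inspect it:

import tensorflow as tf

# New-style global policy: name only, no loss_scale argument.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

policy = tf.keras.mixed_precision.global_policy()
print(policy.compute_dtype)   # 'float16': layers compute in half precision.
print(policy.variable_dtype)  # 'float32': variables stay full precision.

# Restore the default afterwards (useful in tests).
tf.keras.mixed_precision.set_global_policy('float32')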
@@ -122,7 +122,7 @@ def minimize_using_explicit_allreduce(tape,
       in one pack.
   """
   if isinstance(optimizer,
-                tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+                tf.keras.mixed_precision.LossScaleOptimizer):
     # FP16 GPU code path
     with tape:
       scaled_loss = optimizer.get_scaled_loss(loss)
...
@@ -168,7 +168,7 @@ class ImageClassificationTask(base_task.Task):
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
       if isinstance(
-          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+          optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)

     tvars = model.trainable_variables
@@ -176,7 +176,7 @@ class ImageClassificationTask(base_task.Task):
     # Scales back gradient before apply_gradients when LossScaleOptimizer is
     # used.
     if isinstance(
-        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+        optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)

     # Apply gradient clipping.
...
@@ -99,7 +99,8 @@ def run(flags_obj):
   """
   keras_utils.set_session_config(
       enable_xla=flags_obj.enable_xla)
-  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
+  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj),
+                                         use_experimental_api=False)

   if tf.config.list_physical_devices('GPU'):
     if flags_obj.tf_gpu_thread_mode:
...
@@ -81,7 +81,8 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
         self.optimizer,
         use_float16=self.dtype == tf.float16,
         use_graph_rewrite=use_graph_rewrite,
-        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
+        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128),
+        use_experimental_api=False)

     self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
     self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
...
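Putting the last two hunks together, a caller that wants the new code path just threads use_experimental_api=False through the helpers; a sketch, assuming the performance module shown above is importable from the Model Garden (import path assumed):

import tensorflow as tf
from official.modeling import performance  # import path assumed

# Select the non-experimental policy API.
performance.set_mixed_precision_policy(tf.float16, use_experimental_api=False)

# Wrap an optimizer with the non-experimental LossScaleOptimizer.
optimizer = performance.configure_optimizer(
    tf.keras.optimizers.SGD(0.1),
    use_float16=True,
    loss_scale='dynamic',
    use_experimental_api=False)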