Commit 21286f77 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

With float16, always use LossScaleOptimizer.

Before, it was easy to accidentally forget to set runtime.loss_scale, which always had to be set when mixed precision was used; otherwise the model would converge to a worse accuracy. Now, all that is needed to use mixed precision is to set runtime.mixed_precision_dtype=float16.

PiperOrigin-RevId: 383767033
parent c7d4dd39
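To see why the LossScaleOptimizer matters here: float16 has a very narrow exponent range, so small gradient values underflow to zero unless the loss is scaled up first. A minimal NumPy illustration of that underflow (not part of the change; the 2**15 factor is just an example value):

import numpy as np

# float16 cannot represent values below roughly 6e-8, so a small gradient vanishes.
grad = np.float16(1e-8)
print(grad)  # -> 0.0

# Scaling the loss (and therefore the gradients) keeps the value representable.
scaled_grad = np.float16(1e-8 * 2**15)
print(scaled_grad)  # -> ~0.00033

# The optimizer divides by the same factor in float32 before applying updates.
unscaled = np.float32(scaled_grad) / 2**15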
@@ -79,7 +79,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
     # Configuring optimizer when loss_scale is set in runtime config. This helps
     # avoiding overflow/underflow for float16 computations.
-    if runtime_config and runtime_config.loss_scale:
+    if runtime_config:
       optimizer = performance.configure_optimizer(
           optimizer,
           use_float16=runtime_config.mixed_precision_dtype == "float16",
...
@@ -303,13 +303,16 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         },
     })))
     trainer = self.create_test_trainer(config)
-    if mixed_precision_dtype != 'float16':
-      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-    elif mixed_precision_dtype == 'float16' and loss_scale is None:
-      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-    else:
+    if mixed_precision_dtype == 'float16':
       self.assertIsInstance(trainer.optimizer,
                             tf.keras.mixed_precision.LossScaleOptimizer)
+      if loss_scale in (None, 'dynamic'):
+        self.assertTrue(trainer.optimizer.dynamic)
+      else:
+        self.assertFalse(trainer.optimizer.dynamic)
+        self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
+    else:
+      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
     metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertIn('training_loss', metrics)
...
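The new assertions rely on the dynamic and initial_scale attributes of tf.keras.mixed_precision.LossScaleOptimizer. A small sketch of how those attributes behave, using 128 as an example fixed scale:

import tensorflow as tf

# Dynamic loss scaling (what loss_scale=None or 'dynamic' now maps to).
dynamic_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1))
assert dynamic_opt.dynamic

# Fixed loss scaling, e.g. runtime.loss_scale=128.
fixed_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1), dynamic=False, initial_scale=128)
assert not fixed_opt.dynamic
assert fixed_opt.initial_scale == 128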
@@ -20,10 +20,10 @@ import tensorflow as tf
 def configure_optimizer(optimizer,
                         use_float16=False,
                         use_graph_rewrite=False,
-                        loss_scale='dynamic'):
+                        loss_scale=None):
   """Configures optimizer object with performance options."""
   if use_float16:
-    if loss_scale == 'dynamic':
+    if loss_scale in (None, 'dynamic'):
       optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
     else:
       # loss_scale is a number. We interpret that as a fixed loss scale.
...
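With the new default, callers of configure_optimizer no longer need to pass loss_scale to get loss scaling under float16. A hedged usage sketch (the official.modeling.performance import path is an assumption; the file name is not shown in this diff):

import tensorflow as tf
from official.modeling import performance  # assumed module path

# loss_scale omitted: float16 now defaults to dynamic loss scaling.
opt = performance.configure_optimizer(
    tf.keras.optimizers.SGD(0.1), use_float16=True)
assert isinstance(opt, tf.keras.mixed_precision.LossScaleOptimizer)
assert opt.dynamic

# A numeric loss_scale still selects a fixed scale.
fixed = performance.configure_optimizer(
    tf.keras.optimizers.SGD(0.1), use_float16=True, loss_scale=128)
assert not fixed.dynamic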