Commit 92f1d3eb authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 381302776
parent 5c96ad96
@@ -14,29 +14,16 @@
 """Functions and classes related to training performance."""
 
-from absl import logging
 import tensorflow as tf
 
 
 def configure_optimizer(optimizer,
                         use_float16=False,
                         use_graph_rewrite=False,
-                        loss_scale='dynamic',
-                        use_experimental_api=False):
+                        loss_scale='dynamic'):
   """Configures optimizer object with performance options."""
-  if use_experimental_api:
-    logging.warning('Passing use_experimental_api=True is deprecated. The '
-                    'argument will be removed in the future.')
   if use_float16:
-    # TODO(b/171936854): Move all methods to non-experimental api.
-    if use_experimental_api:
-      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
-      # in compile() with the "mixed_float16" policy, but since we do not call
-      # compile(), we must wrap the optimizer manually.
-      optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              optimizer, loss_scale=loss_scale))
-    elif loss_scale == 'dynamic':
+    if loss_scale == 'dynamic':
       optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
     else:
       # loss_scale is a number. We interpret that as a fixed loss scale.
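
For context, the branch this hunk keeps is the non-experimental loss-scaling path. The sketch below shows that logic in isolation, assuming TF 2.4 or later; the SGD optimizer is illustrative, and since the fixed-scale branch is elided after the final comment above, the `dynamic=False`/`initial_scale` arguments are one plausible way to express it with the public `tf.keras.mixed_precision.LossScaleOptimizer` signature, not necessarily the code this commit ships:

    import tensorflow as tf

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

    loss_scale = 'dynamic'  # or a number, interpreted as a fixed loss scale
    if loss_scale == 'dynamic':
      # Dynamic loss scaling: the scale is adjusted automatically during
      # training to avoid float16 gradient underflow/overflow.
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
    else:
      # Fixed loss scaling: gradients are scaled by the same constant
      # throughout training (assumed form of the elided branch).
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
          optimizer, dynamic=False, initial_scale=loss_scale)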
@@ -52,34 +39,17 @@ def configure_optimizer(optimizer,
   return optimizer
 
 
-def set_mixed_precision_policy(dtype, loss_scale=None,
-                               use_experimental_api=False):
-  """Sets mix precision policy."""
-  if use_experimental_api:
-    logging.warning('Passing use_experimental_api=True is deprecated. The '
-                    'argument will be removed in the future.')
-  assert use_experimental_api or loss_scale is None, (
-      'loss_scale cannot be specified if use_experimental_api is False. If the '
-      'non-experimental API is used, specify the loss scaling configuration '
-      'when creating the LossScaleOptimizer instead.'
-  )
+def set_mixed_precision_policy(dtype, loss_scale=None):
+  """Sets the global `tf.keras.mixed_precision.Policy`."""
+  # TODO(b/191894773): Remove loss_scale argument
+  assert loss_scale is None, (
+      'The loss_scale argument must be None. The argument exists for '
+      'historical reasons and will be removed soon.')
   if dtype == tf.float16:
-    # TODO(b/171936854): Move all methods to non-experimental api.
-    if use_experimental_api:
-      policy = tf.keras.mixed_precision.experimental.Policy(
-          'mixed_float16', loss_scale=loss_scale)
-      tf.keras.mixed_precision.experimental.set_policy(policy)
-    else:
-      tf.keras.mixed_precision.set_global_policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
   elif dtype == tf.bfloat16:
-    if use_experimental_api:
-      tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
-    else:
-      tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
+    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
   elif dtype == tf.float32:
-    if use_experimental_api:
-      tf.keras.mixed_precision.experimental.set_policy('float32')
-    else:
-      tf.keras.mixed_precision.set_global_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')
   else:
     raise ValueError('Unexpected dtype: %s' % dtype)
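
The comment deleted from configure_optimizer in the first hunk notes that compile() wraps the optimizer in a LossScaleOptimizer automatically under the 'mixed_float16' policy, so manual wrapping is only needed when compile() is not used. A minimal sketch of that policy-based flow, which set_mixed_precision_policy now delegates to set_global_policy; the model and loss are illustrative:

    import tensorflow as tf

    # Set the global policy before building the model so layers compute in
    # float16 while keeping their variables in float32.
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ])

    # Under 'mixed_float16', compile() wraps the optimizer in a
    # LossScaleOptimizer automatically; no manual wrapping is needed here.
    model.compile(optimizer='sgd', loss='mse')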