Commit e1d7dafb authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 342354179
parent 6d6a78a2
@@ -52,6 +52,11 @@ def configure_optimizer(optimizer,
 def set_mixed_precision_policy(dtype, loss_scale=None,
                                use_experimental_api=True):
   """Sets mix precision policy."""
+  assert use_experimental_api or loss_scale is None, (
+      'loss_scale cannot be specified if use_experimental_api is False. If the '
+      'non-experimental API is used, specify the loss scaling configuration '
+      'when creating the LossScaleOptimizer instead.'
+  )
   if dtype == tf.float16:
     # TODO(b/171936854): Move all methods to non-experimental api.
     if use_experimental_api:
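For context, a rough sketch (not part of the commit) of the two configuration styles the new assert separates: under the experimental API the loss scale travels with the policy object, while under the non-experimental API the policy carries no loss scale and the `LossScaleOptimizer` is configured directly. The optimizer name `opt` is illustrative only.

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(0.1)

# Experimental API (use_experimental_api=True): the loss scale is part of
# the policy itself.
policy = tf.keras.mixed_precision.experimental.Policy(
    "mixed_float16", loss_scale="dynamic")
tf.keras.mixed_precision.experimental.set_policy(policy)

# Non-experimental API (use_experimental_api=False): the global policy has
# no loss scale; loss scaling is configured on the optimizer wrapper.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)  # dynamic by default
```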
@@ -220,10 +220,7 @@ def run_ncf(_):
   model_helpers.apply_clean(FLAGS)
   if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        "mixed_float16",
-        loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic"))
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
   strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=FLAGS.distribution_strategy,
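As a hedged illustration (not from the diff) of what the simpler call above does: with the "mixed_float16" global policy, Keras layers compute in float16 while keeping float32 variables, and `Model.compile` applies dynamic loss scaling automatically, so the explicit `loss_scale` argument is no longer needed at this point.

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

# Layers pick up the global policy: float16 compute, float32 variables.
layer = tf.keras.layers.Dense(4)
print(layer.compute_dtype)  # float16
print(layer.dtype)          # float32 (variable dtype)
```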
@@ -284,12 +281,17 @@ def run_ncf(_):
         optimizer,
         loss_scale=flags_core.get_loss_scale(FLAGS,
                                              default_for_fp16="dynamic"))
-  elif FLAGS.dtype == "fp16" and params["keras_use_ctl"]:
-    # When keras_use_ctl is False, instead Model.fit() automatically applies
-    # loss scaling so we don't need to create a LossScaleOptimizer.
-    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-        optimizer,
-        tf.keras.mixed_precision.experimental.global_policy().loss_scale)
+  elif FLAGS.dtype == "fp16":
+    loss_scale = flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic")
+    # Note Model.compile automatically wraps the optimizer with a
+    # LossScaleOptimizer using dynamic loss scaling. We explicitly wrap it
+    # here for the case where a custom training loop or fixed loss scale is
+    # used.
+    if loss_scale == "dynamic":
+      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
+    else:
+      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
+          optimizer, dynamic=False, initial_scale=loss_scale)
   if params["keras_use_ctl"]:
     train_loss, eval_results = run_ncf_custom_training(
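A rough sketch, outside the commit, of why the explicit wrapping matters when a custom training loop (`keras_use_ctl`) is used: `Model.fit` scales the loss itself, but a CTL must scale the loss and unscale the gradients through the `LossScaleOptimizer` by hand. `model`, `loss_fn`, `x`, and `y` are placeholders, and the fixed scale value 128.0 is arbitrary.

```python
import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(0.1)
# Dynamic loss scaling, matching the diff's "dynamic" branch.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
# Fixed-scale variant, mirroring the else branch:
# optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
#     optimizer, dynamic=False, initial_scale=128.0)

@tf.function
def train_step(model, loss_fn, x, y):
  with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x, training=True))
    # Scale the loss so float16 gradients do not underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Unscale before applying; apply_gradients also updates a dynamic scale.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss
```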