Commit e1d7dafb authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 342354179
parent 6d6a78a2
@@ -52,6 +52,11 @@ def configure_optimizer(optimizer,
 def set_mixed_precision_policy(dtype, loss_scale=None,
                                use_experimental_api=True):
   """Sets mix precision policy."""
+  assert use_experimental_api or loss_scale is None, (
+      'loss_scale cannot be specified if use_experimental_api is False. If the '
+      'non-experimental API is used, specify the loss scaling configuration '
+      'when creating the LossScaleOptimizer instead.'
+  )
   if dtype == tf.float16:
     # TODO(b/171936854): Move all methods to non-experimental api.
     if use_experimental_api:
......
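The assertion added above changes how callers may combine the two APIs. Here is a minimal sketch of the resulting behavior; the import path `official.modeling.performance` is an assumption for illustration and is not shown in the diff:

```python
import tensorflow as tf
from official.modeling import performance  # assumed location of this helper

# Experimental API: a loss scale may still accompany the policy.
performance.set_mixed_precision_policy(tf.float16, loss_scale="dynamic")

# Non-experimental API: no loss_scale allowed; configure loss scaling on the
# LossScaleOptimizer instead.
performance.set_mixed_precision_policy(tf.float16, use_experimental_api=False)

# This combination now trips the new assertion:
# performance.set_mixed_precision_policy(
#     tf.float16, loss_scale=128, use_experimental_api=False)
```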
@@ -220,10 +220,7 @@ def run_ncf(_):
   model_helpers.apply_clean(FLAGS)
   if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        "mixed_float16",
-        loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic"))
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
   strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=FLAGS.distribution_strategy,
@@ -284,12 +281,17 @@ def run_ncf(_):
         optimizer,
         loss_scale=flags_core.get_loss_scale(FLAGS,
                                              default_for_fp16="dynamic"))
-  elif FLAGS.dtype == "fp16" and params["keras_use_ctl"]:
-    # When keras_use_ctl is False, instead Model.fit() automatically applies
-    # loss scaling so we don't need to create a LossScaleOptimizer.
-    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-        optimizer,
-        tf.keras.mixed_precision.experimental.global_policy().loss_scale)
+  elif FLAGS.dtype == "fp16":
+    loss_scale = flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic")
+    # Note Model.compile automatically wraps the optimizer with a
+    # LossScaleOptimizer using dynamic loss scaling. We explicitly wrap it
+    # here for the case where a custom training loop or fixed loss scale is
+    # used.
+    if loss_scale == "dynamic":
+      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
+    else:
+      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
+          optimizer, dynamic=False, initial_scale=loss_scale)
   if params["keras_use_ctl"]:
     train_loss, eval_results = run_ncf_custom_training(
......
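For context, the replacement code above follows the non-experimental mixed precision pattern available in TensorFlow 2.4+. A self-contained sketch of that pattern (the optimizer and loss-scale values are illustrative, not taken from the NCF flags):

```python
import tensorflow as tf

# Set the global mixed precision policy (replaces the experimental
# Policy(...) / set_policy(...) pair used before this change).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

optimizer = tf.keras.optimizers.Adam(1e-3)
loss_scale = "dynamic"  # or a fixed float such as 128.0

if loss_scale == "dynamic":
  # Dynamic loss scaling: the scale is adjusted automatically during training.
  optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
else:
  # Fixed loss scaling: disable dynamic adjustment and pin the initial scale.
  optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
      optimizer, dynamic=False, initial_scale=loss_scale)
```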