Commit cc12499b authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Use nonexperimental LSO API in base_task.py.

This shouldn't break any official models, since I changed all LossScaleOptimizer isinstance checks to use the nonexperimental version (the experimental LSO subclasses the nonexperimental LSO, so changing isinstance checks in this way is always safe).

PiperOrigin-RevId: 366891847
parent 2ae06c8a
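
For context, the pattern the updated files rely on is the stable tf.keras.mixed_precision.LossScaleOptimizer API rather than the deprecated experimental one. A minimal sketch of that pattern, assuming a placeholder SGD optimizer, loss function, and train step that are not taken from base_task.py:

import tensorflow as tf

# Wrap a regular optimizer with the nonexperimental LossScaleOptimizer.
# The loss scale is dynamic by default; a fixed scale can be requested with
# dynamic=False and initial_scale.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.01))

# The experimental LossScaleOptimizer subclasses this class, so an isinstance
# check against the nonexperimental class matches both variants. That is why
# the checks changed in the diffs below remain safe.
assert isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer)

def train_step(model, features, labels, loss_fn):
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(features, training=True))
    # Scale the loss so small float16 gradients do not underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before the gradients are clipped or applied.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss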
@@ -81,7 +81,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
           optimizer,
           use_float16=runtime_config.mixed_precision_dtype == "float16",
           loss_scale=runtime_config.loss_scale,
-          use_experimental_api=True)
+          use_experimental_api=False)
     return optimizer
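
The flag flipped above decides how the shared performance.configure_optimizer helper wraps the optimizer. As an illustrative assumption only (the wrap_with_loss_scaling helper below is hypothetical, not the Model Garden's actual implementation), the use_experimental_api=False path can be thought of as:

import tensorflow as tf

def wrap_with_loss_scaling(optimizer, use_float16, loss_scale=None):
  # Hypothetical stand-in for the nonexperimental code path: only float16
  # training needs loss scaling; float32 and bfloat16 do not.
  if not use_float16:
    return optimizer
  if loss_scale is None or loss_scale == "dynamic":
    # Dynamic loss scaling is the default of the nonexperimental API.
    return tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
  # A numeric loss_scale maps to a fixed (non-dynamic) scale.
  return tf.keras.mixed_precision.LossScaleOptimizer(
      optimizer, dynamic=False, initial_scale=float(loss_scale))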
@@ -86,16 +86,14 @@ class ImageClassificationTask(image_classification.ImageClassificationTask):
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
-      if isinstance(
-          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
     tvars = model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     # Scales back gradient before apply_gradients when LossScaleOptimizer is
     # used.
-    if isinstance(
-        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
     # Apply gradient clipping.
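
The trailing comment in this hunk marks an ordering that matters: gradients are unscaled before clipping and apply_gradients, so the clip threshold is compared against true gradient magnitudes rather than loss-scaled ones. A small sketch of that tail of the step (the helper name and clip_norm value are illustrative, not from the task code):

import tensorflow as tf

def unscale_clip_and_apply(optimizer, grads, tvars, clip_norm=1.0):
  # Unscale first so the clip threshold applies to the real gradient norms.
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    grads = optimizer.get_unscaled_gradients(grads)
  grads, _ = tf.clip_by_global_norm(grads, clip_norm)
  optimizer.apply_gradients(zip(grads, tvars))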
@@ -184,7 +184,7 @@ class YT8MTask(base_task.Task):
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
       if isinstance(optimizer,
-                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+                    tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
       tvars = model.trainable_variables
@@ -192,7 +192,7 @@ class YT8MTask(base_task.Task):
       # Scales back gradient before apply_gradients when LossScaleOptimizer is
       # used.
       if isinstance(optimizer,
-                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+                    tf.keras.mixed_precision.LossScaleOptimizer):
         grads = optimizer.get_unscaled_gradients(grads)
       # Apply gradient clipping.