Commit 47067b87 authored by Pankaj Kanwar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 340527173
parent 76b8a67a
@@ -338,8 +338,7 @@ def run_customized_training_loop(
                                                      post_allreduce_callbacks,
                                                      allreduce_bytes_per_pack)
       else:
-        if isinstance(optimizer,
-                      tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+        if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
           with tape:
             scaled_loss = optimizer.get_scaled_loss(loss)
           scaled_grads = tape.gradient(scaled_loss, training_vars)
...
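
For context, here is a minimal sketch of the custom-loop loss-scaling pattern these hunks migrate, using the stable tf.keras.mixed_precision.LossScaleOptimizer. The model, data, and loss below are illustrative placeholders, not repo code:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
# Wrapping any optimizer enables dynamic loss scaling.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD())

def train_step(x, y):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
    # Scale the loss so small float16 gradients do not flush to zero.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before applying the update.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss
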
@@ -272,15 +272,13 @@ class MaskRCNNTask(base_task.Task):
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
-      if isinstance(
-          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
     tvars = model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     # Scales back gradient when LossScaleOptimizer is used.
-    if isinstance(
-        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
     # Apply gradient clipping.
...
@@ -210,15 +210,13 @@ class RetinaNetTask(base_task.Task):
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
-      if isinstance(
-          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
     tvars = model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     # Scales back gradient when LossScaleOptimizer is used.
-    if isinstance(
-        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
     # Apply gradient clipping.
...
@@ -179,16 +179,14 @@ class SemanticSegmentationTask(base_task.Task):
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
-      if isinstance(
-          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
     tvars = model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     # Scales back gradient before apply_gradients when LossScaleOptimizer is
     # used.
-    if isinstance(
-        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
     # Apply gradient clipping.
...
@@ -151,15 +151,14 @@ class VideoClassificationTask(base_task.Task):
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
       if isinstance(
-          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+          optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
     tvars = model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     # Scales back gradient before apply_gradients when LossScaleOptimizer is
     # used.
-    if isinstance(
-        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
     # Apply gradient clipping.
...
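
All five hunks make the same mechanical substitution: the mixed-precision API left the experimental namespace in TF 2.4, so tf.keras.mixed_precision.experimental.LossScaleOptimizer becomes tf.keras.mixed_precision.LossScaleOptimizer. A hedged sketch of how such an optimizer typically arises, which is what the isinstance checks above detect:

import tensorflow as tf

# With the stable API, setting the global policy makes Model.compile wrap
# the optimizer automatically; custom training loops like the ones in
# these hunks wrap it explicitly instead.
tf.keras.mixed_precision.set_global_policy('mixed_float16')
opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam())
assert isinstance(opt, tf.keras.mixed_precision.LossScaleOptimizer)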