"vscode:/vscode.git/clone" did not exist on "cfc8306d7d61eeac99e3838b0c0074e8a8b6cf54"
Commit 8cf11fd4 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 426529953
parent a02d9f0c
@@ -203,8 +203,7 @@ class BASNetTask(base_task.Task):
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
-      if isinstance(
-          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
     tvars = model.trainable_variables
@@ -212,8 +211,7 @@ class BASNetTask(base_task.Task):
     # Scales back gradient before apply_gradients when LossScaleOptimizer is
     # used.
-    if isinstance(
-        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
     # Apply gradient clipping.
...
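For context: this change replaces the deprecated tf.keras.mixed_precision.experimental.LossScaleOptimizer with the stable tf.keras.mixed_precision.LossScaleOptimizer, which graduated out of the experimental namespace in TF 2.4. The sketch below shows the scale/unscale pattern the hunks above touch, assuming a hypothetical toy model, optimizer, and loss; it is illustrative only and not the BASNet code itself.

import tensorflow as tf

# Hypothetical stand-ins for the task's model and optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.01))

def train_step(features, labels):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(features) - labels))
    # Scale the loss so small gradients survive float16 underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)
  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Scale gradients back before apply_gradients.
  grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(zip(grads, tvars))
  return loss

The isinstance check in the diff exists because the same train_step must also work with a plain optimizer, in which case no scaling or unscaling is applied.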