"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "0606301412150f2878ee2ee56f1932453c4cd79d"
Commit a9684184 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 315331850
parent 39bdc9c2
...@@ -191,8 +191,19 @@ class Task(tf.Module): ...@@ -191,8 +191,19 @@ class Task(tf.Module):
# Scales loss as the default gradients allreduce performs sum inside the # Scales loss as the default gradients allreduce performs sum inside the
# optimizer. # optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
# For mixed precision, when a LossScaleOptimizer is used, the loss is
# scaled to avoid numeric underflow.
if isinstance(optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars) grads = tape.gradient(scaled_loss, tvars)
if isinstance(optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars))) optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss} logs = {self.loss: loss}
if metrics: if metrics:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment