Commit 51e2004c authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 340748908
parent 86057b3d
...
@@ -215,7 +215,6 @@ class ElectraPretrainTask(base_task.Task):
             aux_losses=model.losses)
         # Scales loss as the default gradients allreduce performs sum inside the
         # optimizer.
-        # TODO(b/154564893): enable loss scaling.
         scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
         tvars = model.trainable_variables
         grads = tape.gradient(scaled_loss, tvars)
...
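For context, here is a minimal sketch (not part of this commit or the ELECTRA task itself) of the loss-scaling pattern the retained comment describes: when the gradient allreduce sums per-replica gradients, dividing each replica's loss by num_replicas_in_sync keeps the effective update equal to the global mean. The model, optimizer, and loss below are placeholder choices for illustration only.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Placeholder model and optimizer; the real task builds an ELECTRA model.
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
  optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(features, labels):
  def step_fn(features, labels):
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.reduce_mean(per_example_loss)
      # Scale the loss: the default gradients allreduce performs a sum across
      # replicas, so dividing here yields the mean over all replicas.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    grads = tape.gradient(scaled_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
  per_replica_loss = strategy.run(step_fn, args=(features, labels))
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)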