"...src/git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "4af63f28887a8d52e1d1e006258febb70899f3a4"
Commit b5c6170e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 340748908
parent 3d976e0d
...@@ -215,7 +215,6 @@ class ElectraPretrainTask(base_task.Task): ...@@ -215,7 +215,6 @@ class ElectraPretrainTask(base_task.Task):
aux_losses=model.losses) aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the # Scales loss as the default gradients allreduce performs sum inside the
# optimizer. # optimizer.
# TODO(b/154564893): enable loss scaling.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
tvars = model.trainable_variables tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars) grads = tape.gradient(scaled_loss, tvars)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment