Commit 7dd9852e authored by Jinoo Baek's avatar Jinoo Baek Committed by A. Unique TensorFlower
Browse files

Indentation bug. Divide by num_replicas_in_sync once.

PiperOrigin-RevId: 448508763
parent f7201d1a
......@@ -138,10 +138,10 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
self.tasks[name].process_metrics(task_metrics[name], labels, outputs,
**kwargs)
# Scales loss as the default gradients allreduce performs sum inside
# the optimizer.
scaled_loss = total_loss / tf.distribute.get_strategy(
).num_replicas_in_sync
# Scales loss as the default gradients allreduce performs sum inside
# the optimizer.
scaled_loss = total_loss / tf.distribute.get_strategy(
).num_replicas_in_sync
tvars = multi_task_model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment