Commit 9fc4fd08 authored by Ran Chen, committed by A. Unique TensorFlower

Rename all_reduce_sum_gradients to experimental_aggregate_gradients

For some strategies we don't do an all-reduce, so all_reduce_sum_gradients can be
misleading. The parameter is also marked experimental because of issues with
CentralStorageStrategy.

PiperOrigin-RevId: 302734837
parent 1ec383c8
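
For reference, a minimal sketch (not part of this commit) of how the renamed flag is passed to a TF 2.2-era tf.keras optimizer; the variable and gradient values are placeholders.

import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
var = tf.Variable(1.0)
grad = tf.constant(0.5)

# Default: the optimizer aggregates (sums) gradients across replicas
# before applying them.
opt.apply_gradients([(grad, var)], experimental_aggregate_gradients=True)

# Opt out when the gradients passed in were already aggregated manually.
opt.apply_gradients([(grad, var)], experimental_aggregate_gradients=False)
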
@@ -140,19 +140,19 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
   def apply_gradients(self,
                       grads_and_vars,
                       name=None,
-                      all_reduce_sum_gradients=True):
+                      experimental_aggregate_gradients=True):
     grads, tvars = list(zip(*grads_and_vars))
-    if all_reduce_sum_gradients:
-      # when all_reduce_sum_gradients = False, apply_gradients() no longer
-      # implicitly allreduce gradients, users manually allreduce gradient and
-      # passed the allreduced grads_and_vars. For now, the clip_by_global_norm
-      # will be moved to before the explicit allreduce to keep the math
-      # the same as TF 1 and pre TF 2.2 implementation.
+    if experimental_aggregate_gradients:
+      # When experimental_aggregate_gradients = False, apply_gradients() no
+      # longer implicitly allreduces gradients; users manually allreduce the
+      # gradients and pass the allreduced grads_and_vars. For now,
+      # clip_by_global_norm is applied before the explicit allreduce to keep
+      # the math the same as the TF 1 and pre-TF 2.2 implementations.
       (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
     return super(AdamWeightDecay, self).apply_gradients(
         zip(grads, tvars),
         name=name,
-        all_reduce_sum_gradients=all_reduce_sum_gradients)
+        experimental_aggregate_gradients=experimental_aggregate_gradients)
 
   def _get_lr(self, var_device, var_dtype, apply_state):
     """Retrieves the learning rate with the given state."""
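
To illustrate the manual path described in the comment above, here is a hedged sketch (not from this repo) of clipping before an explicit all-reduce and then applying the pre-aggregated gradients. MirroredStrategy, the toy model, and the loss are assumptions made for the example.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(x, y):

  def step_fn(x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # Clip first so the math matches the TF 1 / pre-TF 2.2 behaviour noted
    # in the comment above, then all-reduce explicitly.
    grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
    grads = tf.distribute.get_replica_context().all_reduce(
        tf.distribute.ReduceOp.SUM, grads)
    # The gradients are already aggregated, so tell the optimizer not to
    # aggregate them again.
    optimizer.apply_gradients(
        zip(grads, model.trainable_variables),
        experimental_aggregate_gradients=False)
    return loss

  return strategy.run(step_fn, args=(x, y))

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])
train_step(x, y)
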
@@ -54,7 +54,7 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   This utility function is used when users intend to explicitly allreduce
   gradients and customize gradient operations before and after the allreduce.
   The allreduced gradients are then passed to optimizer.apply_gradients(
-  all_reduce_sum_gradients=False).
+  experimental_aggregate_gradients=False).
 
   Arguments:
     grads_and_vars: gradients and variables pairs.
@@ -139,4 +139,5 @@ def minimize_using_explicit_allreduce(tape,
   grads_and_vars = zip(allreduced_grads, filtered_training_vars)
   if post_allreduce_callbacks:
     grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
-  optimizer.apply_gradients(grads_and_vars, all_reduce_sum_gradients=False)
+  optimizer.apply_gradients(
+      grads_and_vars, experimental_aggregate_gradients=False)
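
Judging from the _run_callbacks call in this hunk, the post_allreduce_callbacks appear to be callables that take and return a list of (gradient, variable) pairs; that contract is an assumption, and the callback below is purely hypothetical.

import tensorflow as tf

def clip_by_value_callback(grads_and_vars, clip=1.0):
  """Hypothetical post-allreduce callback: element-wise clip each gradient."""
  return [(tf.clip_by_value(g, -clip, clip), v) for g, v in grads_and_vars]

Under that assumption, it would be passed as post_allreduce_callbacks=[clip_by_value_callback] (parameter name taken from this diff) so it runs on the already-allreduced gradients before apply_gradients.
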