Commit 9fc4fd08 authored by Ran Chen, committed by A. Unique TensorFlower

Rename all_reduce_sum_gradients to experimental_aggregate_gradients

For some strategies we don't do an all-reduce, so the name all_reduce_sum_gradients can be
misleading. The parameter is also marked experimental because of issues with
CentralStorageStrategy.

PiperOrigin-RevId: 302734837
parent 1ec383c8
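
At call sites the change is only the keyword argument name. A minimal sketch of the before/after, using hypothetical user code that is not part of this commit (the optimizer and grads_and_vars here are placeholders):

import tensorflow as tf

optimizer = tf.keras.optimizers.Adam()

def apply_manually_reduced(grads_and_vars):
  # Pre TF 2.2 spelling:
  #   optimizer.apply_gradients(grads_and_vars, all_reduce_sum_gradients=False)
  # Spelling after this change:
  optimizer.apply_gradients(
      grads_and_vars, experimental_aggregate_gradients=False)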
@@ -140,19 +140,19 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
   def apply_gradients(self,
                       grads_and_vars,
                       name=None,
-                      all_reduce_sum_gradients=True):
+                      experimental_aggregate_gradients=True):
     grads, tvars = list(zip(*grads_and_vars))
-    if all_reduce_sum_gradients:
-      # when all_reduce_sum_gradients = False, apply_gradients() no longer
-      # implicitly allreduce gradients, users manually allreduce gradient and
-      # passed the allreduced grads_and_vars. For now, the clip_by_global_norm
-      # will be moved to before the explicit allreduce to keep the math
-      # the same as TF 1 and pre TF 2.2 implementation.
+    if experimental_aggregate_gradients:
+      # when experimental_aggregate_gradients = False, apply_gradients() no
+      # longer implicitly allreduce gradients, users manually allreduce gradient
+      # and passed the allreduced grads_and_vars. For now, the
+      # clip_by_global_norm will be moved to before the explicit allreduce to
+      # keep the math the same as TF 1 and pre TF 2.2 implementation.
       (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
     return super(AdamWeightDecay, self).apply_gradients(
         zip(grads, tvars),
         name=name,
-        all_reduce_sum_gradients=all_reduce_sum_gradients)
+        experimental_aggregate_gradients=experimental_aggregate_gradients)
 
   def _get_lr(self, var_device, var_dtype, apply_state):
     """Retrieves the learning rate with the given state."""

@@ -54,7 +54,7 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   This utils function is used when users intent to explicitly allreduce
   gradients and customize gradients operations before and after allreduce.
   The allreduced gradients are then passed to optimizer.apply_gradients(
-  all_reduce_sum_gradients=False).
+  experimental_aggregate_gradients=False).
 
   Arguments:
     grads_and_vars: gradients and variables pairs.

@@ -139,4 +139,5 @@ def minimize_using_explicit_allreduce(tape,
   grads_and_vars = zip(allreduced_grads, filtered_training_vars)
   if post_allreduce_callbacks:
     grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
-  optimizer.apply_gradients(grads_and_vars, all_reduce_sum_gradients=False)
+  optimizer.apply_gradients(
+      grads_and_vars, experimental_aggregate_gradients=False)
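
For context, the explicit-allreduce path that minimize_using_explicit_allreduce supports looks roughly like the sketch below. This is an illustration under stated assumptions, not code from this commit: model, optimizer, loss_fn, features and labels are placeholders, and the function is assumed to run inside a tf.distribute replica context (e.g. via strategy.run).

import tensorflow as tf

def replica_step(model, optimizer, loss_fn, features, labels):
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(features, training=True))
  grads = tape.gradient(loss, model.trainable_variables)
  # Explicit cross-replica aggregation replaces the optimizer's implicit one.
  # (The real helper also filters out None gradients before the all-reduce.)
  replica_ctx = tf.distribute.get_replica_context()
  reduced_grads = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, grads)
  # Tell apply_gradients() not to aggregate the gradients a second time.
  optimizer.apply_gradients(
      zip(reduced_grads, model.trainable_variables),
      experimental_aggregate_gradients=False)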