Commit 6f0e3a0b authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 296561671
parent e185b94c
@@ -137,16 +137,10 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
         use_locking=self._use_locking)
     return tf.no_op()
 
-  def apply_gradients(self,
-                      grads_and_vars,
-                      name=None,
-                      all_reduce_sum_gradients=True):
+  def apply_gradients(self, grads_and_vars, name=None):
     grads, tvars = list(zip(*grads_and_vars))
     (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
-    return super(AdamWeightDecay, self).apply_gradients(
-        zip(grads, tvars),
-        name=name,
-        all_reduce_sum_gradients=all_reduce_sum_gradients)
+    return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
 
   def _get_lr(self, var_device, var_dtype, apply_state):
     """Retrieves the learning rate with the given state."""
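For context, the surviving method clips all incoming gradients by global norm before handing them to the base Adam update; the removed all_reduce_sum_gradients keyword was an extra argument that the override had been forwarding to tf.keras.optimizers.Optimizer.apply_gradients. Below is a minimal standalone sketch of the clipping pattern, assuming TF 2.x; the ClippingAdam class name and the toy training step are illustrative and not part of this commit.

import tensorflow as tf


class ClippingAdam(tf.keras.optimizers.Adam):
  """Adam that clips gradients by global norm, mirroring the patch above."""

  def apply_gradients(self, grads_and_vars, name=None):
    # Split the (gradient, variable) pairs so the gradients can be
    # clipped together as one group.
    grads, tvars = list(zip(*grads_and_vars))
    # Rescale the gradients so their combined global norm is at most 1.0.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
    # Delegate the actual Adam update to the parent class. As in the
    # committed code, `name` is accepted but not forwarded.
    return super(ClippingAdam, self).apply_gradients(zip(grads, tvars))


# Toy usage: one optimizer step on a single variable.
var = tf.Variable([3.0, -4.0])
opt = ClippingAdam(learning_rate=0.1)
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(tf.square(var))
grads = tape.gradient(loss, [var])
opt.apply_gradients(zip(grads, [var]))
print(var.numpy())  # each entry moves toward zero by roughly the learning rate

Note that tf.clip_by_global_norm scales every gradient by the same factor, so the relative direction of the update is preserved; clipping each tensor independently with tf.clip_by_norm would not have that property.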