Unverified Commit 02e8cd55 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix optimizer (#6717)

parent 77abd1e7
......@@ -221,9 +221,9 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
)
return tf.no_op()
def apply_gradients(self, grads_and_vars, name=None, **kwargs):
    """Apply gradients to variables via the base Adam optimizer.

    Forwards any extra keyword arguments (e.g.
    ``experimental_aggregate_gradients``) to
    ``tf.keras.optimizers.Adam.apply_gradients`` so callers that pass
    them are not silently broken.

    Args:
        grads_and_vars: Iterable of ``(gradient, variable)`` pairs.
        name: Optional name for the returned operation.
        **kwargs: Extra keyword arguments passed through to the base
            class's ``apply_gradients``.

    Returns:
        Whatever the base class's ``apply_gradients`` returns (an op
        that applies the gradients).
    """
    # Unzip and re-zip so the base class receives a fresh iterator even
    # when grads_and_vars is a one-shot generator.
    grads, tvars = list(zip(*grads_and_vars))
    return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs)
def _get_lr(self, var_device, var_dtype, apply_state):
"""Retrieves the learning rate with the given state."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment