Commit 0015eedf authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 293378746
parent 5d36f19b
......@@ -125,7 +125,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
def _prepare_local(self, var_device, var_dtype, apply_state):
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
apply_state)
apply_state['weight_decay_rate'] = tf.constant(
apply_state[(var_device, var_dtype)]['weight_decay_rate'] = tf.constant(
self.weight_decay_rate, name='adam_weight_decay_rate')
def _decay_weights_op(self, var, learning_rate, apply_state):
......@@ -133,7 +133,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
if do_decay:
return var.assign_sub(
learning_rate * var *
apply_state['weight_decay_rate'],
apply_state[(var.device, var.dtype.base_dtype)]['weight_decay_rate'],
use_locking=self._use_locking)
return tf.no_op()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment