Commit 0015eedf authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 293378746
parent 5d36f19b
@@ -125,7 +125,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):

   def _prepare_local(self, var_device, var_dtype, apply_state):
     super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
                                                 apply_state)
-    apply_state['weight_decay_rate'] = tf.constant(
+    apply_state[(var_device, var_dtype)]['weight_decay_rate'] = tf.constant(
         self.weight_decay_rate, name='adam_weight_decay_rate')

   def _decay_weights_op(self, var, learning_rate, apply_state):
@@ -133,7 +133,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     if do_decay:
       return var.assign_sub(
           learning_rate * var *
-          apply_state['weight_decay_rate'],
+          apply_state[(var.device, var.dtype.base_dtype)]['weight_decay_rate'],
           use_locking=self._use_locking)
     return tf.no_op()
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment