Commit 19113a57 authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 337609198
parent a3e847b6
......@@ -194,15 +194,27 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
return coefficients['lr_t'], dict(apply_state=apply_state)
def _resource_apply_dense(self, grad, var, apply_state=None):
    """Applies a dense gradient to `var` with decoupled weight decay.

    The weight-decay op is forced to run in the backward pass via a control
    dependency on `grad`, and the Adam update is in turn ordered after the
    decay op.

    Args:
      grad: A dense gradient `Tensor` for `var`.
      var: The `tf.Variable` to update.
      apply_state: Optional dict of precomputed coefficients, as passed by the
        Keras optimizer machinery.

    Returns:
      The op returned by the parent Adam `_resource_apply_dense`.
    """
    # As the weight decay doesn't take any tensors from forward pass as
    # inputs, add a control dependency here to make sure it happens strictly
    # in the backward pass.
    # TODO(b/171088214): Remove it after the control dependency in
    # nested function is fixed.
    with tf.control_dependencies([grad]):
      lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype,
                                  apply_state)
      decay = self._decay_weights_op(var, lr_t, apply_state)
    # Order the Adam step after the decay so decay is not skipped/reordered.
    with tf.control_dependencies([decay]):
      return super(AdamWeightDecay,
                   self)._resource_apply_dense(grad, var, **kwargs)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Applies a sparse gradient to `var` with decoupled weight decay.

    Mirrors `_resource_apply_dense`: the weight-decay op is pinned to the
    backward pass via a control dependency on `grad`, and the Adam sparse
    update is ordered after the decay op.

    Args:
      grad: Gradient values for the rows selected by `indices`.
      var: The `tf.Variable` to update.
      indices: Indices of the rows of `var` that `grad` applies to.
      apply_state: Optional dict of precomputed coefficients, as passed by the
        Keras optimizer machinery.

    Returns:
      The op returned by the parent Adam `_resource_apply_sparse`.
    """
    # As the weight decay doesn't take any tensors from forward pass as
    # inputs, add a control dependency here to make sure it happens strictly
    # in the backward pass.
    # TODO(b/171088214): Remove it after the control dependency in
    # nested function is fixed.
    with tf.control_dependencies([grad]):
      lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype,
                                  apply_state)
      decay = self._decay_weights_op(var, lr_t, apply_state)
    # Order the Adam step after the decay so decay is not skipped/reordered.
    with tf.control_dependencies([decay]):
      return super(AdamWeightDecay,
                   self)._resource_apply_sparse(grad, var, indices, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment