Commit dc4d1121 authored by Ruoxin Sang, committed by A. Unique TensorFlower

Remove explicit control dependency for weight decay.

PiperOrigin-RevId: 341329653
parent 2bde2485
@@ -191,27 +191,15 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     return coefficients['lr_t'], dict(apply_state=apply_state)
 
   def _resource_apply_dense(self, grad, var, apply_state=None):
-    # As the weight decay doesn't take any tensors from forward pass as inputs,
-    # add a control dependency here to make sure it happens strictly in the
-    # backward pass.
-    # TODO(b/171088214): Remove it after the control dependency in
-    # nested function is fixed.
-    with tf.control_dependencies([tf.identity(grad)]):
-      lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
-      decay = self._decay_weights_op(var, lr_t, apply_state)
+    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+    decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
       return super(AdamWeightDecay,
                    self)._resource_apply_dense(grad, var, **kwargs)
 
   def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
-    # As the weight decay doesn't take any tensors from forward pass as inputs,
-    # add a control dependency here to make sure it happens strictly in the
-    # backward pass.
-    # TODO(b/171088214): Remove it after the control dependency in
-    # nested function is fixed.
-    with tf.control_dependencies([tf.identity(grad)]):
-      lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
-      decay = self._decay_weights_op(var, lr_t, apply_state)
+    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+    decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
       return super(AdamWeightDecay,
                    self)._resource_apply_sparse(grad, var, indices, **kwargs)
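Note for readers skimming the diff: the commit drops only the workaround dependency on tf.identity(grad) (tracked in b/171088214); the dependency that matters for correctness, running the decoupled weight-decay op before the gradient-based update, stays. A minimal, self-contained sketch of that surviving pattern follows. The toy variable, gradient, learning rate, and the plain SGD step standing in for the Adam update are assumptions for illustration, not code from this repository.

import tensorflow as tf

# Illustrative only: toy optimizer state; none of these names come from
# the commit itself.
var = tf.Variable([1.0, 2.0])
grad = tf.constant([0.1, 0.1])
lr_t = tf.constant(0.01)
weight_decay_rate = 0.01

@tf.function
def apply_dense_sketch():
  # Decoupled weight decay: shrink the variable by a fraction of its own
  # value, independent of the gradient.
  decay = var.assign_sub(var * weight_decay_rate * lr_t)
  # The dependency this commit keeps: the gradient-based update (a plain
  # SGD step here, standing in for the Adam update) runs only after the
  # decay op has executed.
  with tf.control_dependencies([decay]):
    return var.assign_sub(lr_t * grad)

print(apply_dense_sketch().numpy())

Inside a tf.function, stateful ops are already sequenced in program order by automatic control dependencies, which is why the extra tf.identity(grad) wrapper could be removed; the explicit dependency on the decay op simply makes the required ordering unmistakable.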