"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "bba0f953eadb600d9d680c6bb845bed6fcc9d1b7"
Commit dc4d1121 authored by Ruoxin Sang, committed by A. Unique TensorFlower

Remove explicit control dependency for weight decay.

PiperOrigin-RevId: 341329653
parent 2bde2485
```diff
@@ -191,27 +191,15 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     return coefficients['lr_t'], dict(apply_state=apply_state)
 
   def _resource_apply_dense(self, grad, var, apply_state=None):
-    # As the weight decay doesn't take any tensors from forward pass as inputs,
-    # add a control dependency here to make sure it happens strictly in the
-    # backward pass.
-    # TODO(b/171088214): Remove it after the control dependency in
-    # nested function is fixed.
-    with tf.control_dependencies([tf.identity(grad)]):
-      lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
-      decay = self._decay_weights_op(var, lr_t, apply_state)
+    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+    decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
       return super(AdamWeightDecay,
                    self)._resource_apply_dense(grad, var, **kwargs)
 
   def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
-    # As the weight decay doesn't take any tensors from forward pass as inputs,
-    # add a control dependency here to make sure it happens strictly in the
-    # backward pass.
-    # TODO(b/171088214): Remove it after the control dependency in
-    # nested function is fixed.
-    with tf.control_dependencies([tf.identity(grad)]):
-      lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
-      decay = self._decay_weights_op(var, lr_t, apply_state)
+    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+    decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
       return super(AdamWeightDecay,
                    self)._resource_apply_sparse(grad, var, indices, **kwargs)
```
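The deleted TODO explains the original workaround: because the weight-decay op consumes no tensors from the forward pass, an explicit control dependency on `tf.identity(grad)` was used to pin it to the backward pass while a control-dependency bug in nested functions (b/171088214) was outstanding. What the commit keeps is the decay-then-update ordering, sketched below in a minimal, self-contained form. `TinyAdamWeightDecay` and its two-argument `_decay_weights_op` are illustrative simplifications, not the repository's full `AdamWeightDecay`, and the sketch assumes a TF 2.x release where `tf.keras.optimizers.Adam` is still the OptimizerV2 implementation (on TF 2.11+ subclass `tf.keras.optimizers.legacy.Adam` instead).

```python
import tensorflow as tf

# Minimal sketch of the decay-then-update pattern kept by this commit.
# `TinyAdamWeightDecay` is an illustrative stand-in for the repository's
# AdamWeightDecay; `_decay_weights_op` is simplified to two arguments.
class TinyAdamWeightDecay(tf.keras.optimizers.Adam):

  def __init__(self, weight_decay_rate=0.01, **kwargs):
    super().__init__(**kwargs)
    self.weight_decay_rate = weight_decay_rate

  def _decay_weights_op(self, var, lr_t):
    # Decoupled (AdamW-style) decay: shrink the variable directly rather
    # than folding `wd * var` into the gradient.
    return var.assign_sub(lr_t * self.weight_decay_rate * var)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    key = (var.device, var.dtype.base_dtype)
    coefficients = ((apply_state or {}).get(key)
                    or self._fallback_apply_state(*key))
    decay = self._decay_weights_op(var, coefficients['lr_t'])
    # Ordering only: the Adam step must see the already-decayed variable.
    with tf.control_dependencies([decay]):
      return super()._resource_apply_dense(grad, var, apply_state=apply_state)


# Quick check: one optimizer step on a toy variable.
var = tf.Variable([1.0, 2.0])
opt = TinyAdamWeightDecay(learning_rate=0.1, weight_decay_rate=0.01)
opt.apply_gradients([(tf.constant([0.5, 0.5]), var)])
print(var.numpy())  # decayed, then Adam-updated
```

The surviving `with tf.control_dependencies([decay])` enforces ordering only: it guarantees the Adam step reads the already-decayed variable, without the extra dependency on the gradient that the fixed bug had required.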