Commit 4f50e2fc authored by Ran Chen's avatar Ran Chen Committed by A. Unique TensorFlower
Browse files

Workaround a known issue with control dependency on external tensors

This fix tensorflow_models/official/nlp/bert/run_pretraining failure with
explicit allreduce.

PiperOrigin-RevId: 338521478
parent 2b436a12
......@@ -199,7 +199,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
# backward pass.
# TODO(b/171088214): Remove it after the control dependency in
# nested function is fixed.
with tf.control_dependencies([grad]):
with tf.control_dependencies([tf.identity(grad)]):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
......@@ -212,7 +212,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
# backward pass.
# TODO(b/171088214): Remove it after the control dependency in
# nested function is fixed.
with tf.control_dependencies([grad]):
with tf.control_dependencies([tf.identity(grad)]):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment