"git@developer.sourcefind.cn:lacacy/qwen_lmdeploy.git" did not exist on "902a3e16f937d4f142968dea265c9e03c8559bb8"
Commit 89dd9a4e authored by Zongwei Zhou, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 296767846
parent c25d7711
......@@ -137,10 +137,16 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
use_locking=self._use_locking)
return tf.no_op()
def apply_gradients(self,
                    grads_and_vars,
                    name=None,
                    all_reduce_sum_gradients=True):
  """Applies gradients after clipping them to a global norm of 1.0.

  Resolves the diff residue: the old single-argument signature and the
  bare `super().apply_gradients(zip(...))` call are replaced by the new
  form that forwards `name` and `all_reduce_sum_gradients` to the base
  Adam optimizer, keeping the signature backward compatible (both new
  parameters have defaults).

  Args:
    grads_and_vars: Iterable of (gradient, variable) pairs.
    name: Optional name for the returned apply operation; forwarded to
      the base class.
    all_reduce_sum_gradients: Whether gradients should be summed across
      replicas before being applied; forwarded to the base class.
      Defaults to True, matching the base optimizer's behavior.

  Returns:
    The operation returned by the base class's `apply_gradients`.
  """
  grads, tvars = list(zip(*grads_and_vars))
  # clip_norm=1.0 is hard-coded: gradients are rescaled whenever their
  # combined global norm exceeds 1.0 (standard BERT training practice).
  (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
  return super(AdamWeightDecay, self).apply_gradients(
      zip(grads, tvars),
      name=name,
      all_reduce_sum_gradients=all_reduce_sum_gradients)
def _get_lr(self, var_device, var_dtype, apply_state):
"""Retrieves the learning rate with the given state."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment