Commit 4aafcbfe authored by Chen Qian's avatar Chen Qian Committed by A. Unique TensorFlower
Browse files

Change the casting logic of iterations.

This is to get compatible for an incoming Keras optimizer migration.

PiperOrigin-RevId: 469561702
parent b983b565
......@@ -81,17 +81,18 @@ class PiecewiseConstantDecayWithWarmup(
def _get_learning_rate(self, step):
"""Compute learning rate at given step."""
step = tf.cast(step, dtype=tf.float32)
warmup_steps = tf.cast(self.warmup_steps, dtype=tf.float32)
with tf.name_scope('PiecewiseConstantDecayWithWarmup'):
def warmup_lr(step):
return self.rescaled_lr * (
tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32))
return self.rescaled_lr * (step / warmup_steps)
def piecewise_lr(step):
return tf.compat.v1.train.piecewise_constant(step, self.step_boundaries,
self.lr_values)
return tf.cond(step < self.warmup_steps, lambda: warmup_lr(step),
return tf.cond(step < warmup_steps, lambda: warmup_lr(step),
lambda: piecewise_lr(step))
def get_config(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment