"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "720bc9cf64b6b9aba320bb506adb660fd0b854ef"
Commit 9661bd57 authored by Chen Qian, committed by A. Unique TensorFlower

Change the casting logic of iterations.

This is to be compatible with an upcoming Keras optimizer migration.

PiperOrigin-RevId: 469561702
parent ed232e90
@@ -81,17 +81,18 @@ class PiecewiseConstantDecayWithWarmup(
   def _get_learning_rate(self, step):
     """Compute learning rate at given step."""
+    step = tf.cast(step, dtype=tf.float32)
+    warmup_steps = tf.cast(self.warmup_steps, dtype=tf.float32)
     with tf.name_scope('PiecewiseConstantDecayWithWarmup'):

       def warmup_lr(step):
-        return self.rescaled_lr * (
-            tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32))
+        return self.rescaled_lr * (step / warmup_steps)

       def piecewise_lr(step):
         return tf.compat.v1.train.piecewise_constant(step, self.step_boundaries,
                                                      self.lr_values)

-      return tf.cond(step < self.warmup_steps, lambda: warmup_lr(step),
+      return tf.cond(step < warmup_steps, lambda: warmup_lr(step),
                      lambda: piecewise_lr(step))

   def get_config(self):
...
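
For context, here is a minimal, self-contained sketch of the patched schedule logic. It is not the full PiecewiseConstantDecayWithWarmup class, and the rescaled_lr, warmup_steps, step_boundaries, and lr_values numbers below are hypothetical. It illustrates the point of the up-front casts: a Keras optimizer may hand the schedule its int64 iterations counter, and casting step and warmup_steps to float32 once keeps the comparison, the warmup division, and the piecewise lookup dtype-consistent.

import tensorflow as tf

# Hypothetical schedule values, for illustration only.
RESCALED_LR = 0.1                    # target LR at the end of warmup
WARMUP_STEPS = 500                   # number of linear-warmup steps
STEP_BOUNDARIES = [1000.0, 2000.0]   # decay boundaries (float, like LR_VALUES)
LR_VALUES = [0.1, 0.01, 0.001]       # LR for each interval between boundaries

@tf.function  # graph mode, where the v1 piecewise_constant returns a tensor
def get_learning_rate(step):
  """Standalone version of the patched _get_learning_rate logic."""
  # Cast once up front, as the commit does: `step` may arrive as an int64
  # optimizer iteration counter, while the schedule math is float32.
  step = tf.cast(step, dtype=tf.float32)
  warmup_steps = tf.cast(WARMUP_STEPS, dtype=tf.float32)

  def warmup_lr():
    # Linear warmup: both operands are already float32.
    return RESCALED_LR * (step / warmup_steps)

  def piecewise_lr():
    return tf.compat.v1.train.piecewise_constant(step, STEP_BOUNDARIES,
                                                 LR_VALUES)

  return tf.cond(step < warmup_steps, warmup_lr, piecewise_lr)

# An int64 step, as a Keras optimizer's `iterations` variable would supply it:
print(get_learning_rate(tf.constant(250, dtype=tf.int64)))   # ~0.05 (warmup)
print(get_learning_rate(tf.constant(1500, dtype=tf.int64)))  # 0.01 (decayed)

The @tf.function wrapper matters here: when executing eagerly, the tf.compat.v1.train.piecewise_constant helper returns a callable rather than a tensor, so exercising the logic inside a traced function mirrors how the schedule runs during training.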