Commit bae940dc authored by Haoyu Zhang's avatar Haoyu Zhang Committed by Toby Boyd
Browse files

Fix broken test in V2 (#6755)

parent 4b4dbad1
...@@ -120,9 +120,10 @@ class PiecewiseConstantDecayWithWarmup( ...@@ -120,9 +120,10 @@ class PiecewiseConstantDecayWithWarmup(
def _get_learning_rate(self, step): def _get_learning_rate(self, step):
"""Compute learning rate at given step.""" """Compute learning rate at given step."""
with tf.name_scope(self.name, 'PiecewiseConstantDecayWithWarmup', [ with tf.compat.v1.name_scope(self.name, 'PiecewiseConstantDecayWithWarmup',
self.rescaled_lr, self.step_boundaries, self.lr_values, [self.rescaled_lr, self.step_boundaries,
self.warmup_steps, self.compute_lr_on_cpu]): self.lr_values, self.warmup_steps,
self.compute_lr_on_cpu]):
def warmup_lr(step): def warmup_lr(step):
return self.rescaled_lr * ( return self.rescaled_lr * (
tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32)) tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment