Commit c9ab5a7a authored by Le Hou, committed by A. Unique TensorFlower

Fix corner cases where LR schedules output Inf.

PiperOrigin-RevId: 360719881
parent b3b0664b
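Both schedules touched below raise the step to a (possibly negative) power and divide by a warmup or decay denominator, so a zero step or a zero denominator yields Inf or NaN under IEEE float semantics. A minimal illustration of the failure modes this commit guards against (not part of the commit itself):

import tensorflow as tf

# pow(0, y) diverges for negative exponents, e.g. inverse-power schedules:
print(tf.math.pow(0.0, -1.0).numpy())            # inf

# Division by a zero denominator is the other corner case:
print((tf.constant(1.0) / 0.0).numpy())          # inf
print((tf.constant(0.0) / 0.0).numpy())          # nan

# The fix below clamps the step away from zero before exponentiating:
print(tf.math.pow(tf.math.maximum(0.0, 1.0), -1.0).numpy())  # 1.0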
@@ -120,7 +120,14 @@ class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
       # learning rate will be `global_step/num_warmup_steps * init_lr`.
       global_step_float = tf.cast(step, tf.float32)
       warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)
-      warmup_percent_done = global_step_float / warmup_steps_float
+
+      if self._warmup_steps <= 0:
+        warmup_percent_done = 1.0
+      else:
+        # A zero `step` may cause Inf. So make `step` positive.
+        step_non_zero = tf.math.maximum(global_step_float, 1.0)
+        warmup_percent_done = step_non_zero / warmup_steps_float
+
       warmup_learning_rate = (
           self._initial_learning_rate *
           tf.math.pow(warmup_percent_done, self._power))
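Read outside the diff, the patched warmup branch amounts to the following standalone sketch (names follow the hunk above; the surrounding decay logic is omitted):

import tensorflow as tf

def warmup_lr(step, initial_learning_rate=1.0, warmup_steps=100, power=1.0):
  # Simplified sketch of PolynomialWarmUp's patched warmup branch.
  global_step_float = tf.cast(step, tf.float32)
  warmup_steps_float = tf.cast(warmup_steps, tf.float32)
  if warmup_steps <= 0:
    # Skipping the ratio entirely avoids dividing by a zero warmup length.
    warmup_percent_done = 1.0
  else:
    # A zero `step` may cause Inf. So make `step` positive.
    step_non_zero = tf.math.maximum(global_step_float, 1.0)
    warmup_percent_done = step_non_zero / warmup_steps_float
  return initial_learning_rate * tf.math.pow(warmup_percent_done, power)

print(warmup_lr(0).numpy())                  # 0.01, finite even at step 0
print(warmup_lr(50).numpy())                 # 0.5
print(warmup_lr(0, warmup_steps=0).numpy())  # 1.0 rather than NaN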
@@ -226,8 +233,10 @@ class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
     with tf.name_scope(self._name or "PowerAndLinearDecay"):
       step = tf.cast(step, tf.float32)
       learning_rate = self._initial_learning_rate
-      learning_rate *= tf.math.pow(step, self._power)
-      if self._linear_decay_fraction > 0:
+      # A zero `step` may cause Inf. So make `step` positive.
+      step_non_zero = tf.math.maximum(step, 1.0)
+      learning_rate *= tf.math.pow(step_non_zero, self._power)
+      if self._total_decay_steps * self._linear_decay_fraction > 0:
         learning_rate *= tf.minimum(
             1.0, (self._total_decay_steps - step) /
             (self._total_decay_steps * self._linear_decay_fraction))
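This hunk makes two changes: the step is clamped before exponentiation, and the guard now tests the product total_decay_steps * linear_decay_fraction, i.e. the actual denominator, so total_decay_steps == 0 can no longer reach the division. A standalone sketch, with parameter values assumed for illustration (chosen so they reproduce the expected values in the tests further down):

import tensorflow as tf

def power_and_linear_decay_lr(step, initial_learning_rate=1.0, power=-1.0,
                              total_decay_steps=100.0,
                              linear_decay_fraction=0.5):
  # Sketch of PowerAndLinearDecay's patched __call__ body; the parameter
  # defaults here are assumptions, not values from the commit.
  step = tf.cast(step, tf.float32)
  learning_rate = initial_learning_rate
  # A zero `step` may cause Inf. So make `step` positive.
  step_non_zero = tf.math.maximum(step, 1.0)
  learning_rate *= tf.math.pow(step_non_zero, power)
  # Guarding on the full denominator also covers total_decay_steps == 0,
  # which the old `linear_decay_fraction > 0` check let through.
  if total_decay_steps * linear_decay_fraction > 0:
    learning_rate *= tf.minimum(
        1.0, (total_decay_steps - step) /
        (total_decay_steps * linear_decay_fraction))
  return learning_rate

print(power_and_linear_decay_lr(0).numpy())   # 1.0, no Inf at step 0
print(power_and_linear_decay_lr(40).numpy())  # 0.025      == 1/40
print(power_and_linear_decay_lr(60).numpy())  # 0.01333... == 1/60 * 0.8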
@@ -313,7 +313,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     lr = opt_factory.build_learning_rate()

     for step, value in expected_lr_step_values:
-      self.assertAlmostEqual(lr(step).numpy(), value)
+      self.assertAlmostEqual(lr(step).numpy(), value, places=6)

   def test_power_lr_schedule(self):
     params = {
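The extra places=6 relaxes assertAlmostEqual from its unittest default of seven decimal places, presumably to absorb float32 rounding now that the schedules are checked at step 0 and step 1, where the expected value is 1.0. Near 1.0 a single float32 ulp is larger than the default tolerance of 5e-8, as this illustration (assumed rationale, not stated in the commit) shows:

import numpy as np

# The float32 value one ulp below 1.0:
lr_value = np.nextafter(np.float32(1.0), np.float32(0.0))  # 0.99999994
print(abs(lr_value - 1.0))                 # ~6.0e-08, fails places=7
print(round(abs(lr_value - 1.0), 6) == 0)  # True, passes places=6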
@@ -331,7 +331,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
         }
       }
     }
-    expected_lr_step_values = [[1, 1.0], [250, 1. / 250.]]
+    expected_lr_step_values = [[0, 1.0], [1, 1.0], [250, 1. / 250.]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()
@@ -357,7 +357,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
         }
       }
     }
-    expected_lr_step_values = [[1, 1.0], [40, 1. / 40.], [60, 1. / 60. * 0.8]]
+    expected_lr_step_values = [
+        [0, 1.0], [1, 1.0], [40, 1. / 40.], [60, 1. / 60. * 0.8]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()