Commit 3a6079ce authored by Zhichao Lu, committed by TF Object Detection Team

Add an exponential decay learning rate schedule with warmup.

PiperOrigin-RevId: 344743049
parent 067d35f9
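For orientation, the schedule added below ramps the learning rate linearly from warmup_learning_rate to learning_rate_base over the first warmup_steps steps, then applies exponential decay by learning_rate_decay_factor every learning_rate_decay_steps steps (staircased by default), never dropping below min_learning_rate. A minimal pure-Python sketch of that arithmetic, with a hypothetical helper name and constants taken from the new unit test, reproduces the expected rates; it is illustrative only and not part of the commit:

def warmup_then_exponential_decay(step, base_lr=1.0, decay_steps=3,
                                  decay_factor=0.1, warmup_lr=0.5,
                                  warmup_steps=2, min_lr=0.05):
  # Linear warmup from warmup_lr to base_lr over warmup_steps steps.
  if step < warmup_steps:
    slope = (base_lr - warmup_lr) / warmup_steps
    return warmup_lr + slope * step
  # Staircased exponential decay, counted from the end of warmup,
  # floored at min_lr.
  decayed = base_lr * decay_factor ** ((step - warmup_steps) // decay_steps)
  return max(decayed, min_lr)

print([warmup_then_exponential_decay(s) for s in range(9)])
# [0.5, 0.75, 1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.05]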
@@ -23,6 +23,14 @@ from six.moves import zip
 import tensorflow.compat.v1 as tf


+def _learning_rate_return_value(eager_decay_rate):
+  """Helper function to return proper learning rate based on tf version."""
+  if tf.executing_eagerly():
+    return eager_decay_rate
+  else:
+    return eager_decay_rate()
+
+
 def exponential_decay_with_burnin(global_step,
                                   learning_rate_base,
                                   learning_rate_decay_steps,
@@ -76,10 +84,65 @@ def exponential_decay_with_burnin(global_step,
         tf.constant(burnin_learning_rate),
         post_burnin_learning_rate), min_learning_rate, name='learning_rate')

-  if tf.executing_eagerly():
-    return eager_decay_rate
-  else:
-    return eager_decay_rate()
+  return _learning_rate_return_value(eager_decay_rate)


+def exponential_decay_with_warmup(global_step,
+                                  learning_rate_base,
+                                  learning_rate_decay_steps,
+                                  learning_rate_decay_factor,
+                                  warmup_learning_rate=0.0,
+                                  warmup_steps=0,
+                                  min_learning_rate=0.0,
+                                  staircase=True):
"""Exponential decay schedule with warm up period.
Args:
global_step: int tensor representing global step.
learning_rate_base: base learning rate.
learning_rate_decay_steps: steps to take between decaying the learning rate.
Note that this includes the number of burn-in steps.
learning_rate_decay_factor: multiplicative factor by which to decay learning
rate.
warmup_learning_rate: initial learning rate during warmup period.
warmup_steps: number of steps to use warmup learning rate.
min_learning_rate: the minimum learning rate.
staircase: whether use staircase decay.
Returns:
If executing eagerly:
returns a no-arg callable that outputs the (scalar)
float tensor learning rate given the current value of global_step.
If in a graph:
immediately returns a (scalar) float tensor representing learning rate.
"""
+
+  def eager_decay_rate():
+    """Callable to compute the learning rate."""
+    post_warmup_learning_rate = tf.train.exponential_decay(
+        learning_rate_base,
+        global_step - warmup_steps,
+        learning_rate_decay_steps,
+        learning_rate_decay_factor,
+        staircase=staircase)
+    if callable(post_warmup_learning_rate):
+      post_warmup_learning_rate = post_warmup_learning_rate()
+
+    if learning_rate_base < warmup_learning_rate:
+      raise ValueError('learning_rate_base must be larger or equal to '
+                       'warmup_learning_rate.')
+    slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
+    warmup_rate = slope * tf.cast(global_step,
+                                  tf.float32) + warmup_learning_rate
+    learning_rate = tf.where(
+        tf.less(tf.cast(global_step, tf.int32), tf.constant(warmup_steps)),
+        warmup_rate,
+        tf.maximum(post_warmup_learning_rate, min_learning_rate),
+        name='learning_rate')
+    return learning_rate
+
+  return _learning_rate_return_value(eager_decay_rate)
+
+
 def cosine_decay_with_warmup(global_step,
@@ -142,10 +205,7 @@ def cosine_decay_with_warmup(global_step,
     return tf.where(global_step > total_steps, 0.0, learning_rate,
                     name='learning_rate')

-  if tf.executing_eagerly():
-    return eager_decay_rate
-  else:
-    return eager_decay_rate()
+  return _learning_rate_return_value(eager_decay_rate)


 def manual_stepping(global_step, boundaries, rates, warmup=False):
@@ -212,7 +272,5 @@ def manual_stepping(global_step, boundaries, rates, warmup=False):
         [0] * num_boundaries))
     return tf.reduce_sum(rates * tf.one_hot(rate_index, depth=num_boundaries),
                          name='learning_rate')

-  if tf.executing_eagerly():
-    return eager_decay_rate
-  else:
-    return eager_decay_rate()
+  return _learning_rate_return_value(eager_decay_rate)
@@ -50,6 +50,28 @@ class LearningSchedulesTest(test_case.TestCase):
     exp_rates = [.5, .5, 1, 1, 1, .1, .1, .1, .05]
     self.assertAllClose(output_rates, exp_rates, rtol=1e-4)

+  def testExponentialDecayWithWarmup(self):
+    def graph_fn(global_step):
+      learning_rate_base = 1.0
+      learning_rate_decay_steps = 3
+      learning_rate_decay_factor = .1
+      warmup_learning_rate = .5
+      warmup_steps = 2
+      min_learning_rate = .05
+      learning_rate = learning_schedules.exponential_decay_with_warmup(
+          global_step, learning_rate_base, learning_rate_decay_steps,
+          learning_rate_decay_factor, warmup_learning_rate, warmup_steps,
+          min_learning_rate)
+      assert learning_rate.op.name.endswith('learning_rate')
+      return (learning_rate,)
+
+    output_rates = [
+        self.execute(graph_fn, [np.array(i).astype(np.int64)]) for i in range(9)
+    ]
+
+    exp_rates = [.5, .75, 1, 1, 1, .1, .1, .1, .05]
+    self.assertAllClose(output_rates, exp_rates, rtol=1e-4)
+
   def testCosineDecayWithWarmup(self):
     def graph_fn(global_step):
       learning_rate_base = 1.0
...
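Usage note: because every schedule now funnels through the shared _learning_rate_return_value helper, callers receive a no-arg callable under eager execution and a scalar float tensor in graph mode. A hedged sketch of consuming the new schedule, assuming the module path object_detection.utils.learning_schedules used by the TF Object Detection API:

import tensorflow.compat.v1 as tf
from object_detection.utils import learning_schedules

global_step = tf.train.get_or_create_global_step()
lr = learning_schedules.exponential_decay_with_warmup(
    global_step, learning_rate_base=1.0, learning_rate_decay_steps=3,
    learning_rate_decay_factor=0.1, warmup_learning_rate=0.5,
    warmup_steps=2, min_learning_rate=0.05)

if tf.executing_eagerly():
  current_lr = lr()            # eager: a no-arg callable was returned
else:
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    current_lr = sess.run(lr)  # graph: a scalar float tensor was returned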