Commit 5e854f25 authored by Zhichao Lu, committed by lzc5123016

Make learning_schedules.py TPU compatible

PiperOrigin-RevId: 185425302
parent 2e539758
object_detection/utils/BUILD

```diff
@@ -264,6 +264,7 @@ py_test(
     srcs = ["learning_schedules_test.py"],
     deps = [
         ":learning_schedules",
+        ":test_case",
         "//tensorflow",
     ],
 )
```
object_detection/utils/learning_schedules.py

```diff
@@ -53,10 +53,10 @@ def exponential_decay_with_burnin(global_step,
       learning_rate_decay_steps,
       learning_rate_decay_factor,
       staircase=True)
-  return tf.cond(
-      tf.less(global_step, burnin_steps),
-      lambda: tf.convert_to_tensor(burnin_learning_rate),
-      lambda: post_burnin_learning_rate)
+  return tf.where(
+      tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
+      tf.constant(burnin_learning_rate),
+      post_burnin_learning_rate)
 
 
 def cosine_decay_with_warmup(global_step,
```
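For readers wondering why `tf.cond` is swapped for `tf.where` throughout this commit: TPU programs are compiled by XLA, which favors a static dataflow graph, and `tf.cond` introduces data-dependent control flow; the three-argument `tf.where` instead computes both candidate tensors and merely selects between them. Below is a self-contained sketch of the new burn-in pattern, not part of the commit; it uses the TF 1.x API and the schedule constants from the test file later in this diff:

```python
import tensorflow as tf  # TF 1.x API, matching the code under review

global_step = tf.placeholder(tf.int64, [])
burnin_steps = 2
burnin_learning_rate = 0.5
post_burnin_learning_rate = tf.train.exponential_decay(
    1.0, global_step, decay_steps=3, decay_rate=0.1, staircase=True)

# Both branches are evaluated unconditionally; tf.where only selects,
# so no data-dependent control flow reaches the compiled graph.
learning_rate = tf.where(
    tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
    tf.constant(burnin_learning_rate),
    post_burnin_learning_rate)

with tf.Session() as sess:
  rates = [sess.run(learning_rate, feed_dict={global_step: i})
           for i in range(8)]
  print(rates)  # ~[0.5, 0.5, 1.0, 0.1, 0.1, 0.1, 0.01, 0.01]
```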
object_detection/utils/learning_schedules.py

```diff
@@ -100,9 +100,10 @@ def cosine_decay_with_warmup(global_step,
     slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
     pre_cosine_learning_rate = slope * tf.cast(
         global_step, tf.float32) + warmup_learning_rate
-    learning_rate = tf.cond(
-        tf.less(global_step, warmup_steps), lambda: pre_cosine_learning_rate,
-        lambda: learning_rate)
+    learning_rate = tf.where(
+        tf.less(tf.cast(global_step, tf.int32), warmup_steps),
+        pre_cosine_learning_rate,
+        learning_rate)
   return learning_rate
```
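A quick arithmetic check of the warmup branch, using the constants from the test file later in this diff: the slope works out to (1.0 - 0.1) / 9 = 0.1 per step, matching the expected rates of 0.1, 0.5, and 0.9 at steps 0, 4, and 8. At step 9 the `tf.less` comparison turns false and the cosine branch takes over (the test expects 1.0 there).

```python
# Constants from testCosineDecayWithWarmup below.
learning_rate_base = 1.0
warmup_learning_rate = 0.1
warmup_steps = 9

slope = (learning_rate_base - warmup_learning_rate) / warmup_steps  # 0.1/step
for step in [0, 4, 8]:
  # Linear warmup, exactly as in pre_cosine_learning_rate above.
  print(step, round(slope * step + warmup_learning_rate, 3))
# 0 0.1
# 4 0.5
# 8 0.9
```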
object_detection/utils/learning_schedules.py

```diff
@@ -141,10 +142,15 @@ def manual_stepping(global_step, boundaries, rates):
   if len(rates) != len(boundaries) + 1:
     raise ValueError('Number of provided learning rates must exceed '
                      'number of boundary points by exactly 1.')
-  step_boundaries = tf.constant(boundaries, tf.int64)
+  step_boundaries = tf.constant(boundaries, tf.int32)
+  num_boundaries = len(boundaries)
   learning_rates = tf.constant(rates, tf.float32)
-  unreached_boundaries = tf.reshape(
-      tf.where(tf.greater(step_boundaries, global_step)), [-1])
-  unreached_boundaries = tf.concat([unreached_boundaries, [len(boundaries)]], 0)
-  index = tf.reshape(tf.reduce_min(unreached_boundaries), [1])
-  return tf.reshape(tf.slice(learning_rates, index, [1]), [])
+  index = tf.reduce_min(
+      tf.where(
+          # Casting global step to tf.int32 is dangerous, but necessary to be
+          # compatible with TPU.
+          tf.greater(step_boundaries, tf.cast(global_step, tf.int32)),
+          tf.constant(range(num_boundaries), dtype=tf.int32),
+          tf.constant([num_boundaries] * num_boundaries, dtype=tf.int32)))
+  return tf.reduce_sum(learning_rates * tf.one_hot(index, len(rates),
+                                                   dtype=tf.float32))
```
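This hunk changes more than dtypes. The old code used `tf.where` in its single-argument form, which returns the indices where the condition holds and therefore has a data-dependent output shape, and then read the rate with `tf.slice` at a runtime index; both are awkward under XLA's static-shape requirements. The rewrite uses the three-argument, elementwise `tf.where` (fixed output shape) and reads out the selected rate with a one-hot dot product. A standalone sketch of the same selection logic, not part of the commit (TF 1.x, constants from the test):

```python
import tensorflow as tf  # TF 1.x API

global_step = tf.placeholder(tf.int64, [])
boundaries = [2, 3, 7]
rates = [1.0, 2.0, 3.0, 4.0]

num_boundaries = len(boundaries)
step_boundaries = tf.constant(boundaries, tf.int32)
learning_rates = tf.constant(rates, tf.float32)

# Elementwise select: position i keeps its own index while boundary i is
# still ahead of global_step, and saturates to num_boundaries once passed.
# The minimum is therefore the index of the first unreached boundary.
index = tf.reduce_min(
    tf.where(
        tf.greater(step_boundaries, tf.cast(global_step, tf.int32)),
        tf.constant(list(range(num_boundaries)), dtype=tf.int32),
        tf.constant([num_boundaries] * num_boundaries, dtype=tf.int32)))

# A one-hot dot product reads out rates[index] with static shapes throughout.
learning_rate = tf.reduce_sum(
    learning_rates * tf.one_hot(index, len(rates), dtype=tf.float32))

with tf.Session() as sess:
  print([sess.run(learning_rate, feed_dict={global_step: i})
         for i in range(10)])
  # [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
```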
object_detection/utils/learning_schedules_test.py

```diff
@@ -14,65 +14,66 @@
 # ==============================================================================
 """Tests for object_detection.utils.learning_schedules."""
+import numpy as np
 import tensorflow as tf
 
 from object_detection.utils import learning_schedules
+from object_detection.utils import test_case
 
 
-class LearningSchedulesTest(tf.test.TestCase):
+class LearningSchedulesTest(test_case.TestCase):
 
   def testExponentialDecayWithBurnin(self):
-    global_step = tf.placeholder(tf.int32, [])
-    learning_rate_base = 1.0
-    learning_rate_decay_steps = 3
-    learning_rate_decay_factor = .1
-    burnin_learning_rate = .5
-    burnin_steps = 2
-    exp_rates = [.5, .5, 1, .1, .1, .1, .01, .01]
-    learning_rate = learning_schedules.exponential_decay_with_burnin(
-        global_step, learning_rate_base, learning_rate_decay_steps,
-        learning_rate_decay_factor, burnin_learning_rate, burnin_steps)
-    with self.test_session() as sess:
-      output_rates = []
-      for input_global_step in range(8):
-        output_rate = sess.run(learning_rate,
-                               feed_dict={global_step: input_global_step})
-        output_rates.append(output_rate)
-      self.assertAllClose(output_rates, exp_rates)
+    def graph_fn(global_step):
+      learning_rate_base = 1.0
+      learning_rate_decay_steps = 3
+      learning_rate_decay_factor = .1
+      burnin_learning_rate = .5
+      burnin_steps = 2
+      learning_rate = learning_schedules.exponential_decay_with_burnin(
+          global_step, learning_rate_base, learning_rate_decay_steps,
+          learning_rate_decay_factor, burnin_learning_rate, burnin_steps)
+      return (learning_rate,)
+    output_rates = [
+        self.execute(graph_fn, [np.array(i).astype(np.int64)]) for i in range(8)
+    ]
+    exp_rates = [.5, .5, 1, .1, .1, .1, .01, .01]
+    self.assertAllClose(output_rates, exp_rates, rtol=1e-4)
 
   def testCosineDecayWithWarmup(self):
-    global_step = tf.placeholder(tf.int32, [])
-    learning_rate_base = 1.0
-    total_steps = 100
-    warmup_learning_rate = 0.1
-    warmup_steps = 9
-    input_global_steps = [0, 4, 8, 9, 100]
-    exp_rates = [0.1, 0.5, 0.9, 1.0, 0]
-    learning_rate = learning_schedules.cosine_decay_with_warmup(
-        global_step, learning_rate_base, total_steps,
-        warmup_learning_rate, warmup_steps)
-    with self.test_session() as sess:
-      output_rates = []
-      for input_global_step in input_global_steps:
-        output_rate = sess.run(learning_rate,
-                               feed_dict={global_step: input_global_step})
-        output_rates.append(output_rate)
-      self.assertAllClose(output_rates, exp_rates)
+    def graph_fn(global_step):
+      learning_rate_base = 1.0
+      total_steps = 100
+      warmup_learning_rate = 0.1
+      warmup_steps = 9
+      learning_rate = learning_schedules.cosine_decay_with_warmup(
+          global_step, learning_rate_base, total_steps,
+          warmup_learning_rate, warmup_steps)
+      return (learning_rate,)
+    exp_rates = [0.1, 0.5, 0.9, 1.0, 0]
+    input_global_steps = [0, 4, 8, 9, 100]
+    output_rates = [
+        self.execute(graph_fn, [np.array(step).astype(np.int64)])
+        for step in input_global_steps
+    ]
+    self.assertAllClose(output_rates, exp_rates)
 
   def testManualStepping(self):
-    global_step = tf.placeholder(tf.int64, [])
-    boundaries = [2, 3, 7]
-    rates = [1.0, 2.0, 3.0, 4.0]
-    exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
-    learning_rate = learning_schedules.manual_stepping(global_step, boundaries,
-                                                       rates)
-    with self.test_session() as sess:
-      output_rates = []
-      for input_global_step in range(10):
-        output_rate = sess.run(learning_rate,
-                               feed_dict={global_step: input_global_step})
-        output_rates.append(output_rate)
-      self.assertAllClose(output_rates, exp_rates)
+    def graph_fn(global_step):
+      boundaries = [2, 3, 7]
+      rates = [1.0, 2.0, 3.0, 4.0]
+      learning_rate = learning_schedules.manual_stepping(
+          global_step, boundaries, rates)
+      return (learning_rate,)
+    output_rates = [
+        self.execute(graph_fn, [np.array(i).astype(np.int64)])
+        for i in range(10)
+    ]
+    exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
+    self.assertAllClose(output_rates, exp_rates)
 
 
 if __name__ == '__main__':
   tf.test.main()
```
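The tests now build their graphs through `test_case.TestCase.execute`, which in the object_detection repo runs `graph_fn` on a TPU when one is available and falls back to CPU otherwise (its real implementation lives in object_detection/utils/test_case.py, added as a BUILD dependency above). Purely as an illustration of the calling convention seen in these tests, here is a much-simplified, CPU-only sketch of what such a harness might look like; it is an assumption, not the library's actual code:

```python
import numpy as np
import tensorflow as tf  # TF 1.x API


class TestCase(tf.test.TestCase):
  """Illustrative stand-in for object_detection.utils.test_case.TestCase."""

  def execute(self, graph_fn, inputs):
    """Builds graph_fn in a fresh graph, feeds `inputs`, returns its output.

    The real helper also knows how to rewrite the graph for TPU; this
    sketch covers only the CPU path.
    """
    with tf.Graph().as_default() as g:
      placeholders = [
          tf.placeholder_with_default(v, shape=v.shape) for v in inputs
      ]
      outputs = graph_fn(*placeholders)
      with self.test_session(graph=g) as sess:
        results = sess.run(outputs)
    # graph_fn returns a tuple; unwrap single outputs so callers can
    # compare scalars directly, as the assertions above do.
    return results[0] if len(results) == 1 else results
```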