Commit 5e854f25 authored by Zhichao Lu, committed by lzc5123016

Make learning_schedules.py TPU compatible

PiperOrigin-RevId: 185425302
parent 2e539758
@@ -264,6 +264,7 @@ py_test(
     srcs = ["learning_schedules_test.py"],
     deps = [
         ":learning_schedules",
+        ":test_case",
         "//tensorflow",
     ],
 )
......
@@ -53,10 +53,10 @@ def exponential_decay_with_burnin(global_step,
       learning_rate_decay_steps,
       learning_rate_decay_factor,
       staircase=True)
-  return tf.cond(
-      tf.less(global_step, burnin_steps),
-      lambda: tf.convert_to_tensor(burnin_learning_rate),
-      lambda: post_burnin_learning_rate)
+  return tf.where(
+      tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
+      tf.constant(burnin_learning_rate),
+      post_burnin_learning_rate)


 def cosine_decay_with_warmup(global_step,
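Note on the change above: tf.cond emits dynamic control flow, which XLA's TPU compilation handles poorly, while tf.where evaluates both branch values and merely selects between two same-shape tensors, keeping the graph fully static. A minimal standalone sketch of the pattern in TF 1.x style (the constant values here are illustrative, not taken from the commit):

import tensorflow as tf

# Both candidate rates are computed unconditionally; tf.where only
# selects between two scalar tensors, so no control-flow ops appear
# in the graph (a requirement for XLA/TPU compilation).
global_step = tf.constant(1, tf.int64)
post_burnin_learning_rate = tf.train.exponential_decay(
    learning_rate=1.0, global_step=global_step,
    decay_steps=3, decay_rate=0.1, staircase=True)
learning_rate = tf.where(
    tf.less(tf.cast(global_step, tf.int32), tf.constant(2)),  # burnin_steps
    tf.constant(0.5),                                         # burnin rate
    post_burnin_learning_rate)

with tf.Session() as sess:
  print(sess.run(learning_rate))  # 0.5: step 1 is still inside burnin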
@@ -100,9 +100,10 @@ def cosine_decay_with_warmup(global_step,
   slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
   pre_cosine_learning_rate = slope * tf.cast(
       global_step, tf.float32) + warmup_learning_rate
-  learning_rate = tf.cond(
-      tf.less(global_step, warmup_steps), lambda: pre_cosine_learning_rate,
-      lambda: learning_rate)
+  learning_rate = tf.where(
+      tf.less(tf.cast(global_step, tf.int32), warmup_steps),
+      pre_cosine_learning_rate,
+      learning_rate)
   return learning_rate
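As a quick sanity check on the warmup branch (plain arithmetic, using the values from testCosineDecayWithWarmup further down): with learning_rate_base = 1.0, warmup_learning_rate = 0.1 and warmup_steps = 9, the slope is 0.1, so the pre-cosine rate at step g is 0.1 * g + 0.1.

# Plain-Python check of the linear warmup arithmetic in the hunk above.
learning_rate_base = 1.0
warmup_learning_rate = 0.1
warmup_steps = 9
slope = (learning_rate_base - warmup_learning_rate) / warmup_steps  # 0.1
for g in (0, 4, 8):
  print(g, slope * g + warmup_learning_rate)  # 0.1, 0.5, 0.9
# At g == warmup_steps the tf.where switches to the cosine branch, which
# starts at learning_rate_base (1.0) and decays to 0 by total_steps (100),
# matching the test's expected rates [0.1, 0.5, 0.9, 1.0, 0].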
@@ -141,10 +142,15 @@ def manual_stepping(global_step, boundaries, rates):
   if len(rates) != len(boundaries) + 1:
     raise ValueError('Number of provided learning rates must exceed '
                      'number of boundary points by exactly 1.')
-  step_boundaries = tf.constant(boundaries, tf.int64)
+  step_boundaries = tf.constant(boundaries, tf.int32)
+  num_boundaries = len(boundaries)
   learning_rates = tf.constant(rates, tf.float32)
-  unreached_boundaries = tf.reshape(
-      tf.where(tf.greater(step_boundaries, global_step)), [-1])
-  unreached_boundaries = tf.concat([unreached_boundaries, [len(boundaries)]], 0)
-  index = tf.reshape(tf.reduce_min(unreached_boundaries), [1])
-  return tf.reshape(tf.slice(learning_rates, index, [1]), [])
+  index = tf.reduce_min(
+      tf.where(
+          # Casting global step to tf.int32 is dangerous, but necessary to be
+          # compatible with TPU.
+          tf.greater(step_boundaries, tf.cast(global_step, tf.int32)),
+          tf.constant(range(num_boundaries), dtype=tf.int32),
+          tf.constant([num_boundaries] * num_boundaries, dtype=tf.int32)))
+  return tf.reduce_sum(learning_rates * tf.one_hot(index, len(rates),
+                                                   dtype=tf.float32))
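The old implementation used single-argument tf.where (which returns a dynamically shaped index list) and tf.slice at a data-dependent offset; neither has the static shapes XLA needs. The rewrite uses elementwise three-argument tf.where to pick the index of the first unreached boundary, then turns the lookup into a fixed-shape one-hot reduction. A NumPy sketch of the same selection logic (reference only; the function name is mine, values from testManualStepping below):

import numpy as np

def manual_stepping_reference(global_step, boundaries, rates):
  # Index of the first boundary not yet reached; num_boundaries once all
  # boundaries have been passed, which selects the final rate.
  num_boundaries = len(boundaries)
  candidates = np.where(np.array(boundaries) > global_step,
                        np.arange(num_boundaries), num_boundaries)
  index = candidates.min()
  # One-hot mask: the dynamic index becomes a fixed-shape reduction,
  # mirroring tf.one_hot + tf.reduce_sum in the new code.
  return float(np.sum(np.array(rates) * np.eye(len(rates))[index]))

print([manual_stepping_reference(g, [2, 3, 7], [1.0, 2.0, 3.0, 4.0])
       for g in range(10)])
# [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]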
@@ -14,64 +14,65 @@
 # ==============================================================================
 """Tests for object_detection.utils.learning_schedules."""
+import numpy as np
 import tensorflow as tf

 from object_detection.utils import learning_schedules
+from object_detection.utils import test_case


-class LearningSchedulesTest(tf.test.TestCase):
+class LearningSchedulesTest(test_case.TestCase):

   def testExponentialDecayWithBurnin(self):
-    global_step = tf.placeholder(tf.int32, [])
-    learning_rate_base = 1.0
-    learning_rate_decay_steps = 3
-    learning_rate_decay_factor = .1
-    burnin_learning_rate = .5
-    burnin_steps = 2
-    exp_rates = [.5, .5, 1, .1, .1, .1, .01, .01]
-    learning_rate = learning_schedules.exponential_decay_with_burnin(
-        global_step, learning_rate_base, learning_rate_decay_steps,
-        learning_rate_decay_factor, burnin_learning_rate, burnin_steps)
-    with self.test_session() as sess:
-      output_rates = []
-      for input_global_step in range(8):
-        output_rate = sess.run(learning_rate,
-                               feed_dict={global_step: input_global_step})
-        output_rates.append(output_rate)
-      self.assertAllClose(output_rates, exp_rates)
+    def graph_fn(global_step):
+      learning_rate_base = 1.0
+      learning_rate_decay_steps = 3
+      learning_rate_decay_factor = .1
+      burnin_learning_rate = .5
+      burnin_steps = 2
+      learning_rate = learning_schedules.exponential_decay_with_burnin(
+          global_step, learning_rate_base, learning_rate_decay_steps,
+          learning_rate_decay_factor, burnin_learning_rate, burnin_steps)
+      return (learning_rate,)
+
+    output_rates = [
+        self.execute(graph_fn, [np.array(i).astype(np.int64)]) for i in range(8)
+    ]
+    exp_rates = [.5, .5, 1, .1, .1, .1, .01, .01]
+    self.assertAllClose(output_rates, exp_rates, rtol=1e-4)

   def testCosineDecayWithWarmup(self):
-    global_step = tf.placeholder(tf.int32, [])
-    learning_rate_base = 1.0
-    total_steps = 100
-    warmup_learning_rate = 0.1
-    warmup_steps = 9
-    input_global_steps = [0, 4, 8, 9, 100]
-    exp_rates = [0.1, 0.5, 0.9, 1.0, 0]
-    learning_rate = learning_schedules.cosine_decay_with_warmup(
-        global_step, learning_rate_base, total_steps,
-        warmup_learning_rate, warmup_steps)
-    with self.test_session() as sess:
-      output_rates = []
-      for input_global_step in input_global_steps:
-        output_rate = sess.run(learning_rate,
-                               feed_dict={global_step: input_global_step})
-        output_rates.append(output_rate)
-      self.assertAllClose(output_rates, exp_rates)
+    def graph_fn(global_step):
+      learning_rate_base = 1.0
+      total_steps = 100
+      warmup_learning_rate = 0.1
+      warmup_steps = 9
+      learning_rate = learning_schedules.cosine_decay_with_warmup(
+          global_step, learning_rate_base, total_steps,
+          warmup_learning_rate, warmup_steps)
+      return (learning_rate,)
+
+    exp_rates = [0.1, 0.5, 0.9, 1.0, 0]
+    input_global_steps = [0, 4, 8, 9, 100]
+    output_rates = [
+        self.execute(graph_fn, [np.array(step).astype(np.int64)])
+        for step in input_global_steps
+    ]
+    self.assertAllClose(output_rates, exp_rates)

   def testManualStepping(self):
-    global_step = tf.placeholder(tf.int64, [])
-    boundaries = [2, 3, 7]
-    rates = [1.0, 2.0, 3.0, 4.0]
-    exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
-    learning_rate = learning_schedules.manual_stepping(global_step, boundaries,
-                                                       rates)
-    with self.test_session() as sess:
-      output_rates = []
-      for input_global_step in range(10):
-        output_rate = sess.run(learning_rate,
-                               feed_dict={global_step: input_global_step})
-        output_rates.append(output_rate)
-      self.assertAllClose(output_rates, exp_rates)
+    def graph_fn(global_step):
+      boundaries = [2, 3, 7]
+      rates = [1.0, 2.0, 3.0, 4.0]
+      learning_rate = learning_schedules.manual_stepping(
+          global_step, boundaries, rates)
+      return (learning_rate,)
+
+    output_rates = [
+        self.execute(graph_fn, [np.array(i).astype(np.int64)])
+        for i in range(10)
+    ]
+    exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
+    self.assertAllClose(output_rates, exp_rates)


 if __name__ == '__main__':
......
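The refactored tests wrap graph construction in a graph_fn and run it through test_case.TestCase.execute, so the same test body can be executed on TPU or CPU. test_case.py itself is not part of this diff; the following is only a hypothetical sketch of the CPU path of such a helper, to show how graph_fn and the np.array inputs fit together:

import tensorflow as tf

class TestCase(tf.test.TestCase):
  """Hypothetical sketch; the real object_detection test_case.TestCase
  presumably also knows how to rewrite graph_fn for TPU execution."""

  def execute(self, graph_fn, inputs):
    # Feed each numpy input through a placeholder of matching dtype and
    # shape, run the graph, and unwrap single-element result tuples.
    with self.test_session(graph=tf.Graph()) as sess:
      placeholders = [tf.placeholder_with_default(v, shape=v.shape)
                      for v in inputs]
      outputs = graph_fn(*placeholders)
      results = sess.run(outputs)
      return results[0] if len(results) == 1 else results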