"tools/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "f7cdbcb5ee0590cbf28ce0237f393202d68d3670"
Unverified commit 9d38e894 authored by Haoyu Zhang, committed by GitHub

Use TensorFlow ops for Keras LearningRateSchedule (#6739)

* Add learning rate tensor. This makes training slower

* Improve LearningRateSchedule with better efficiency

* Fix lint error

* Replace constant definition with existing one
parent 9c5253f1
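
The gist of the change: instead of having a Keras callback recompute the learning rate in Python and assign it to the optimizer before every batch, the learning rate is expressed as a tf.keras.optimizers.schedules.LearningRateSchedule and handed to the optimizer, so it is evaluated as a TensorFlow op inside the training step. A minimal sketch of the two approaches (the callback and the schedule values here are illustrative stand-ins, not code from this commit):

import tensorflow as tf

# Before: a callback assigns a new learning rate to the optimizer on every
# batch from Python (hypothetical minimal callback, standing in for the
# repo's LearningRateBatchScheduler).
class AssignLrCallback(tf.keras.callbacks.Callback):
  def on_batch_begin(self, batch, logs=None):
    tf.keras.backend.set_value(self.model.optimizer.lr, 0.1)

# After: the optimizer evaluates a LearningRateSchedule as a TensorFlow op
# at each step; no per-batch Python-side assignment is needed.
schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[1000, 2000], values=[0.1, 0.01, 0.001])
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)
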
@@ -79,6 +79,71 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
                      'change learning rate to %s.', self.epochs, batch, lr)
 
 
+class PiecewiseConstantDecayWithWarmup(
+    tf.keras.optimizers.schedules.LearningRateSchedule):
+  """Piecewise constant decay with warmup schedule."""
+
+  def __init__(self, batch_size, epoch_size, warmup_epochs, boundaries,
+               multipliers, compute_lr_on_cpu=True, name=None):
+    super(PiecewiseConstantDecayWithWarmup, self).__init__()
+    if len(boundaries) != len(multipliers) - 1:
+      raise ValueError('The length of boundaries must be 1 less than the '
+                       'length of multipliers')
+
+    base_lr_batch_size = 256
+    num_batches_per_epoch = epoch_size // batch_size
+
+    self.rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
+    self.step_boundaries = [float(num_batches_per_epoch) * x
+                            for x in boundaries]
+    self.lr_values = [self.rescaled_lr * m for m in multipliers]
+    self.warmup_steps = warmup_epochs * num_batches_per_epoch
+    self.compute_lr_on_cpu = compute_lr_on_cpu
+    self.name = name
+
+    self.cached_learning_rate_op = None
+
+  def __call__(self, step):
+    if tf.executing_eagerly():
+      return self._get_learning_rate(step)
+
+    # In a graph or tf.function, the current optimizer implementation
+    # repeatedly calls the schedule and thus creates new ops each time.
+    # To avoid this, we cache the learning rate op when not executing eagerly.
+    if self.cached_learning_rate_op is None:
+      if self.compute_lr_on_cpu:
+        with tf.device('/device:CPU:0'):
+          self.cached_learning_rate_op = self._get_learning_rate(step)
+      else:
+        self.cached_learning_rate_op = self._get_learning_rate(step)
+
+    return self.cached_learning_rate_op
+
+  def _get_learning_rate(self, step):
+    """Compute learning rate at given step."""
+    with tf.name_scope(self.name, 'PiecewiseConstantDecayWithWarmup',
+                       [self.rescaled_lr, self.step_boundaries, self.lr_values,
+                        self.warmup_steps, self.compute_lr_on_cpu]):
+      def warmup_lr(step):
+        return self.rescaled_lr * (
+            tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32))
+      def piecewise_lr(step):
+        return tf.compat.v1.train.piecewise_constant(
+            step, self.step_boundaries, self.lr_values)
+      return tf.cond(step < self.warmup_steps,
+                     lambda: warmup_lr(step),
+                     lambda: piecewise_lr(step))
+
+  def get_config(self):
+    return {
+        'rescaled_lr': self.rescaled_lr,
+        'step_boundaries': self.step_boundaries,
+        'lr_values': self.lr_values,
+        'warmup_steps': self.warmup_steps,
+        'compute_lr_on_cpu': self.compute_lr_on_cpu,
+        'name': self.name
+    }
+
+
 class ProfilerCallback(tf.keras.callbacks.Callback):
   """Save profiles in specified step range to log directory."""
@@ -159,20 +224,23 @@ def set_gpu_thread_mode_and_count(flags_obj):
       flags_obj.datasets_num_private_threads)
 
 
-def get_optimizer():
+def get_optimizer(learning_rate=0.1):
   """Returns optimizer to use."""
   # The learning_rate is overwritten at the beginning of each step by callback.
-  return gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
+  return gradient_descent_v2.SGD(learning_rate=learning_rate, momentum=0.9)
 
 
 def get_callbacks(learning_rate_schedule_fn, num_images):
   """Returns common callbacks."""
   time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
+  callbacks = [time_callback]
 
-  lr_callback = LearningRateBatchScheduler(
-      learning_rate_schedule_fn,
-      batch_size=FLAGS.batch_size,
-      num_images=num_images)
-
-  callbacks = [time_callback, lr_callback]
+  if not FLAGS.use_tensor_lr:
+    lr_callback = LearningRateBatchScheduler(
+        learning_rate_schedule_fn,
+        batch_size=FLAGS.batch_size,
+        num_images=num_images)
+    callbacks.append(lr_callback)
 
   if FLAGS.enable_tensorboard:
     tensorboard_callback = tf.keras.callbacks.TensorBoard(
@@ -264,6 +332,8 @@ def define_keras_flags():
   flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
   flags.DEFINE_boolean(name='use_trivial_model', default=False,
                        help='Whether to use a trivial Keras model.')
+  flags.DEFINE_boolean(name='use_tensor_lr', default=False,
+                       help='Use learning rate tensor instead of a callback.')
   flags.DEFINE_boolean(
       name='enable_xla', default=False,
       help='Whether to enable XLA auto jit compilation. This is still an '
...
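
A rough, standalone illustration of the warmup-then-piecewise logic implemented by PiecewiseConstantDecayWithWarmup above, using assumed numbers (batch_size=1024 against the base batch size of 256 gives a rescaled peak LR of 0.1 * 1024 / 256 = 0.4) and the public PiecewiseConstantDecay schedule in place of tf.compat.v1.train.piecewise_constant:

import tensorflow as tf

rescaled_lr = 0.4                    # 0.1 * 1024 / 256 (assumed batch size)
warmup_steps = 500.0                 # warmup_epochs * num_batches_per_epoch
step_boundaries = [1000.0, 2000.0]   # epoch boundaries converted to steps
lr_values = [0.4, 0.04, 0.004]       # rescaled_lr * multipliers

piecewise = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    step_boundaries, lr_values)

def learning_rate(step):
  # Linear warmup up to warmup_steps, then piecewise constant decay.
  step = tf.cast(step, tf.float32)
  warmup = rescaled_lr * step / warmup_steps
  return tf.where(step < warmup_steps, warmup, piecewise(step))

for s in [0, 250, 500, 1500, 2500]:
  print(s, float(learning_rate(s)))
# Roughly: 0 -> 0.0, 250 -> 0.2, 500 -> 0.4, 1500 -> 0.04, 2500 -> 0.004
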
@@ -170,8 +170,18 @@ def run(flags_obj):
       dtype=dtype,
       drop_remainder=drop_remainder)
 
+  lr_schedule = 0.1
+  if flags_obj.use_tensor_lr:
+    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
+        batch_size=flags_obj.batch_size,
+        epoch_size=imagenet_main.NUM_IMAGES['train'],
+        warmup_epochs=LR_SCHEDULE[0][1],
+        boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
+        multipliers=list(p[0] for p in LR_SCHEDULE),
+        compute_lr_on_cpu=True)
+
   with strategy_scope:
-    optimizer = keras_common.get_optimizer()
+    optimizer = keras_common.get_optimizer(lr_schedule)
     if dtype == 'float16':
       # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
       # can be enabled with a single line of code.
...
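
For context on the constructor arguments above: LR_SCHEDULE in the ResNet script is a list of (multiplier, start-epoch) pairs, and the generator expressions unpack it so that the first entry defines the warmup length while the remaining start epochs become decay boundaries. An assumed example (values are illustrative, not taken from this diff):

LR_SCHEDULE = [  # (multiplier, epoch to start) tuples -- illustrative values
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]

warmup_epochs = LR_SCHEDULE[0][1]                 # 5
boundaries = list(p[1] for p in LR_SCHEDULE[1:])  # [30, 60, 80]
multipliers = list(p[0] for p in LR_SCHEDULE)     # [1.0, 0.1, 0.01, 0.001]

# Satisfies the check in PiecewiseConstantDecayWithWarmup.__init__:
assert len(boundaries) == len(multipliers) - 1
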