Commit 5b2575c2 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 282293668
parent 1a40ebdd
@@ -227,7 +227,7 @@ def define_transformer_flags():
# pylint: enable=unused-variable
def get_callbacks():
def get_callbacks(steps_per_epoch):
"""Returns common callbacks."""
callbacks = []
if FLAGS.enable_time_history:
@@ -243,7 +243,8 @@ def get_callbacks():
profiler_callback = keras_utils.get_profiler_callback(
FLAGS.model_dir,
FLAGS.profile_steps,
FLAGS.enable_tensorboard)
FLAGS.enable_tensorboard,
steps_per_epoch)
callbacks.append(profiler_callback)
return callbacks
......
@@ -159,6 +159,7 @@ class TransformerTask(object):
params["repeat_dataset"] = None
params["dtype"] = flags_core.get_tf_dtype(flags_obj)
params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
params["steps_between_evals"] = flags_obj.steps_between_evals
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
@@ -387,7 +388,7 @@ class TransformerTask(object):
params["hidden_size"],
params["learning_rate_warmup_steps"])
scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
callbacks = misc.get_callbacks()
callbacks = misc.get_callbacks(params["steps_between_evals"])
callbacks.append(scheduler_callback)
ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
......
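In the Transformer task a Keras "epoch" is not a pass over the dataset; `model.fit` runs `steps_between_evals` steps per epoch, which is why that flag value is what `misc.get_callbacks` now receives as `steps_per_epoch`. A minimal sketch of the relationship, with assumed flag values:

```python
# Sketch with assumed flag values: in the Transformer task one Keras "epoch"
# is an evaluation interval of steps_between_evals global steps.
train_steps = 300_000        # assumed --train_steps
steps_between_evals = 5_000  # assumed --steps_between_evals

epochs = train_steps // steps_between_evals
print(epochs)  # 60 -> callbacks see 60 "epochs" of 5,000 steps each
```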
@@ -93,7 +93,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
(epoch, epoch_run_time))
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard):
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
steps_per_epoch):
"""Validate profile_steps flag value and return profiler callback."""
profile_steps_error_message = (
'profile_steps must be a comma separated pair of positive integers, '
@@ -114,26 +115,39 @@ def get_profiler_callback(model_dir, profile_steps, enable_tensorboard):
'TensorBoard callback profiles the 2nd step (unless otherwise '
'specified). Please make sure the steps profiled by the two callbacks '
'do not overlap.')
return ProfilerCallback(model_dir, start_step, stop_step)
return ProfilerCallback(model_dir, start_step, stop_step, steps_per_epoch)
class ProfilerCallback(tf.keras.callbacks.Callback):
"""Save profiles in specified step range to log directory."""
def __init__(self, log_dir, start_step, stop_step):
def __init__(self, log_dir, start_step, stop_step, steps_per_epoch):
super(ProfilerCallback, self).__init__()
self.log_dir = log_dir
self.start_step = start_step
self.stop_step = stop_step
self.start_epoch = start_step // steps_per_epoch
self.stop_epoch = stop_step // steps_per_epoch
self.start_step_in_epoch = start_step % steps_per_epoch
self.stop_step_in_epoch = stop_step % steps_per_epoch
self.should_start = False
self.should_stop = False
def on_epoch_begin(self, epoch, logs=None):
if epoch == self.start_epoch:
self.should_start = True
if epoch == self.stop_epoch:
self.should_stop = True
def on_batch_begin(self, batch, logs=None):
if batch == self.start_step:
if batch == self.start_step_in_epoch and self.should_start:
self.should_start = False
profiler.start()
tf.compat.v1.logging.info('Profiler started at Step %s', self.start_step)
def on_batch_end(self, batch, logs=None):
if batch == self.stop_step:
if batch == self.stop_step_in_epoch and self.should_stop:
self.should_stop = False
results = profiler.stop()
profiler.save(self.log_dir, results)
tf.compat.v1.logging.info(
......
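Keras reports `batch` indices relative to the current epoch in `on_batch_begin`/`on_batch_end`, so a global `--profile_steps` range has to be split into an epoch index plus an offset within that epoch; that is what the new `steps_per_epoch` argument makes possible. A standalone sketch of the mapping with assumed numbers (not the callback itself):

```python
# Standalone sketch of the start/stop bookkeeping added above (assumed values,
# not the actual ProfilerCallback): map a global step range onto epochs.
steps_per_epoch = 400
start_step, stop_step = 900, 950   # e.g. --profile_steps=900,950

start_epoch, start_step_in_epoch = divmod(start_step, steps_per_epoch)
stop_epoch, stop_step_in_epoch = divmod(stop_step, steps_per_epoch)

# Profiling starts at batch 100 of epoch 2 and stops at batch 150 of epoch 2;
# the should_start/should_stop flags keep it from firing in earlier epochs
# that happen to reach the same within-epoch batch index.
print(start_epoch, start_step_in_epoch)  # 2 100
print(stop_epoch, stop_step_in_epoch)    # 2 150
```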
@@ -38,7 +38,7 @@ LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
def learning_rate_schedule(current_epoch,
current_batch,
batches_per_epoch,
steps_per_epoch,
batch_size):
"""Handles linear scaling rule, gradual warmup, and LR decay.
@@ -48,14 +48,14 @@ def learning_rate_schedule(current_epoch,
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
batches_per_epoch: integer, number of steps in an epoch.
steps_per_epoch: integer, number of steps in an epoch.
batch_size: integer, total batch size.
Returns:
Adjusted learning rate.
"""
initial_lr = BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / batches_per_epoch
epoch = current_epoch + float(current_batch) / steps_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
# Learning rate increases linearly per step.
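Only the parameter is renamed here; the fractional `epoch` still mixes the epoch index with progress through the current epoch and then drives the warmup ramp. A hedged sketch of that arithmetic with assumed constants (the real `BASE_LEARNING_RATE` and `LR_SCHEDULE` are defined elsewhere in this file):

```python
# Sketch of the linear-scaling-plus-warmup arithmetic (assumed constants; the
# real values are defined elsewhere in the file being changed).
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(1.0, 5)]  # assumed: multiplier 1.0, warmup ends at epoch 5

def warmup_lr(current_epoch, current_batch, steps_per_epoch, batch_size):
    initial_lr = BASE_LEARNING_RATE * batch_size / 256
    # Fractional epoch: how far training has progressed, measured in epochs.
    epoch = current_epoch + float(current_batch) / steps_per_epoch
    warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
    # During warmup the rate rises linearly toward the scaled base rate.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch

print(warmup_lr(current_epoch=0, current_batch=100,
                steps_per_epoch=400, batch_size=256))  # 0.005
```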
@@ -79,10 +79,10 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
output (float).
"""
def __init__(self, schedule, batch_size, num_images):
def __init__(self, schedule, batch_size, steps_per_epoch):
super(LearningRateBatchScheduler, self).__init__()
self.schedule = schedule
self.batches_per_epoch = num_images / batch_size
self.steps_per_epoch = steps_per_epoch
self.batch_size = batch_size
self.epochs = -1
self.prev_lr = -1
@@ -96,7 +96,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
"""Executes before step begins."""
lr = self.schedule(self.epochs,
batch,
self.batches_per_epoch,
self.steps_per_epoch,
self.batch_size)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be float.')
@@ -120,13 +120,12 @@ class PiecewiseConstantDecayWithWarmup(
'length of multipliers')
base_lr_batch_size = 256
num_batches_per_epoch = epoch_size // batch_size
steps_per_epoch = epoch_size // batch_size
self.rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
self.step_boundaries = [float(num_batches_per_epoch) * x
for x in boundaries]
self.step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
self.lr_values = [self.rescaled_lr * m for m in multipliers]
self.warmup_steps = warmup_epochs * num_batches_per_epoch
self.warmup_steps = warmup_epochs * steps_per_epoch
self.compute_lr_on_cpu = compute_lr_on_cpu
self.name = name
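The constructor change is a rename plus a reflowed list comprehension; the derived quantities are easier to follow with concrete numbers. A worked sketch under assumed inputs:

```python
# Worked sketch with assumed inputs (not the class itself): how the schedule's
# step boundaries and warmup length follow from steps_per_epoch.
BASE_LEARNING_RATE = 0.1                       # assumed base rate
base_lr_batch_size = 256
epoch_size, batch_size = 50_000, 128           # assumed dataset/batch sizes
boundaries = [91, 136, 182]                    # assumed decay epochs
multipliers = [1.0, 0.1, 0.01, 0.001]          # assumed multipliers
warmup_epochs = 5                              # assumed warmup length

steps_per_epoch = epoch_size // batch_size                           # 390
rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size   # 0.05
step_boundaries = [float(steps_per_epoch) * x for x in boundaries]   # [35490.0, 53040.0, 70980.0]
lr_values = [rescaled_lr * m for m in multipliers]
warmup_steps = warmup_epochs * steps_per_epoch                       # 1950
```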
@@ -208,7 +207,7 @@ def get_optimizer(learning_rate=0.1):
# TODO(hongkuny,haoyuzhang): make cifar model use_tensor_lr to clean up code.
def get_callbacks(learning_rate_schedule_fn=None, num_images=None):
def get_callbacks(steps_per_epoch, learning_rate_schedule_fn=None):
"""Returns common callbacks."""
time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
callbacks = [time_callback]
@@ -217,7 +216,7 @@ def get_callbacks(learning_rate_schedule_fn=None, num_images=None):
lr_callback = LearningRateBatchScheduler(
learning_rate_schedule_fn,
batch_size=FLAGS.batch_size,
num_images=num_images)
steps_per_epoch=steps_per_epoch)
callbacks.append(lr_callback)
if FLAGS.enable_tensorboard:
@@ -229,7 +228,8 @@ def get_callbacks(learning_rate_schedule_fn=None, num_images=None):
profiler_callback = keras_utils.get_profiler_callback(
FLAGS.model_dir,
FLAGS.profile_steps,
FLAGS.enable_tensorboard)
FLAGS.enable_tensorboard,
steps_per_epoch)
callbacks.append(profiler_callback)
return callbacks
@@ -332,7 +332,7 @@ def define_keras_flags(dynamic_loss_scale=True):
'ignored if train_epochs is set to be larger than 1. ')
flags.DEFINE_string(
name='profile_steps', default=None,
help='Save profiling data to model dir at given range of steps. The '
help='Save profiling data to model dir at given range of global steps. The '
'value must be a comma separated pair of positive integers, specifying '
'the first and last step to profile. For example, "--profile_steps=2,4" '
'triggers the profiler to process 3 steps, starting from the 2nd step. '
......
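The help text now says "global steps" because the profiler callback above interprets the pair across epoch boundaries. A small sketch of how such a value might be parsed and validated under the constraints the message describes (hypothetical helper, not the actual flag-handling code):

```python
# Hypothetical parse/validate helper for a profile_steps value, following the
# constraints stated in the help text above.
def parse_profile_steps(value):
    parts = value.split(',')
    if len(parts) != 2:
        raise ValueError('profile_steps must be a comma separated pair')
    start_step, stop_step = (int(p) for p in parts)
    if start_step <= 0 or stop_step < start_step:
        raise ValueError('profile_steps must be an increasing pair of positive integers')
    return start_step, stop_step

print(parse_profile_steps('2,4'))  # (2, 4) -> profiles 3 global steps, starting at step 2
```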
@@ -171,14 +171,15 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = common.get_callbacks(
learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])
train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
steps_per_epoch = (
cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
train_epochs = flags_obj.train_epochs
if flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps)
callbacks = common.get_callbacks(steps_per_epoch, learning_rate_schedule)
# if multiple epochs, ignore the train_steps flag.
if train_epochs <= 1 and flags_obj.train_steps:
steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
train_epochs = 1
num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
@@ -201,7 +202,7 @@ def run(flags_obj):
history = model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
steps_per_epoch=steps_per_epoch,
callbacks=callbacks,
validation_steps=num_eval_steps,
validation_data=validation_data,
@@ -225,7 +226,6 @@ def define_cifar_flags():
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
model_dir='/tmp/cifar10_model',
train_epochs=182,
epochs_between_evals=10,
batch_size=128)
......
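With this change a single `steps_per_epoch` value drives `model.fit`, the learning-rate schedule, and the profiler, and the single-epoch `train_steps` cap modifies it in place. The arithmetic, assuming the 50,000-image CIFAR-10 training set and the default batch size of 128 set above:

```python
# Sketch of the steps_per_epoch computation and the single-epoch cap
# (assumed flag values; 50,000 is the CIFAR-10 training-set size).
num_train_images = 50_000
batch_size = 128         # default --batch_size above
train_epochs = 1         # assumed --train_epochs
train_steps_flag = 100   # assumed --train_steps

steps_per_epoch = num_train_images // batch_size   # 390
# With only one epoch requested, --train_steps can shorten it; with multiple
# epochs the flag is ignored so every epoch keeps its full length.
if train_epochs <= 1 and train_steps_flag:
    steps_per_epoch = min(train_steps_flag, steps_per_epoch)
print(steps_per_epoch)  # 100
```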
@@ -187,19 +187,20 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = common.get_callbacks(
common.learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
steps_per_epoch = (
imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
train_epochs = flags_obj.train_epochs
callbacks = common.get_callbacks(steps_per_epoch,
common.learning_rate_schedule)
if flags_obj.enable_checkpoint_and_export:
ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
save_weights_only=True))
train_steps = (
imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
train_epochs = flags_obj.train_epochs
# if multiple epochs, ignore the train_steps flag.
if train_epochs <= 1 and flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps)
steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
train_epochs = 1
num_eval_steps = (
@@ -225,7 +226,7 @@ def run(flags_obj):
history = model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
steps_per_epoch=steps_per_epoch,
callbacks=callbacks,
validation_steps=num_eval_steps,
validation_data=validation_data,
......