"vscode:/vscode.git/clone" did not exist on "ac0825f513f98c1ecd5ab95da103992bac1cfda5"
Commit 5b2575c2 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 282293668
parent 1a40ebdd
@@ -227,7 +227,7 @@ def define_transformer_flags():
   # pylint: enable=unused-variable


-def get_callbacks():
+def get_callbacks(steps_per_epoch):
   """Returns common callbacks."""
   callbacks = []
   if FLAGS.enable_time_history:
@@ -243,7 +243,8 @@ def get_callbacks():
     profiler_callback = keras_utils.get_profiler_callback(
         FLAGS.model_dir,
         FLAGS.profile_steps,
-        FLAGS.enable_tensorboard)
+        FLAGS.enable_tensorboard,
+        steps_per_epoch)
     callbacks.append(profiler_callback)
   return callbacks
......
@@ -159,6 +159,7 @@ class TransformerTask(object):
     params["repeat_dataset"] = None
     params["dtype"] = flags_core.get_tf_dtype(flags_obj)
     params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
+    params["steps_between_evals"] = flags_obj.steps_between_evals

     self.distribution_strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy=flags_obj.distribution_strategy,
@@ -387,7 +388,7 @@ class TransformerTask(object):
         params["hidden_size"],
         params["learning_rate_warmup_steps"])
     scheduler_callback = optimizer.LearningRateScheduler(sfunc, init_steps)
-    callbacks = misc.get_callbacks()
+    callbacks = misc.get_callbacks(params["steps_between_evals"])
     callbacks.append(scheduler_callback)
     ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
     callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
......
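Note (not part of the diff): the transformer runner drives model.fit in chunks of steps_between_evals batches, so from the Keras callbacks' point of view that flag is the number of steps per "epoch". A simplified, hypothetical sketch of such a loop illustrates why it is the value threaded into misc.get_callbacks; the real TransformerTask.train also handles evaluation and checkpointing.

    # Simplified sketch only, not the actual TransformerTask.train code.
    def train_in_eval_chunks(model, train_ds, callbacks, params):
      current_step = 0
      current_iteration = 0
      while current_step < params["train_steps"]:
        # Each fit() call runs at most steps_between_evals batches, so a Keras
        # "epoch" here is one evaluation interval.
        train_steps_per_eval = min(params["steps_between_evals"],
                                   params["train_steps"] - current_step)
        model.fit(train_ds,
                  initial_epoch=current_iteration,
                  epochs=current_iteration + 1,
                  steps_per_epoch=train_steps_per_eval,
                  callbacks=callbacks)
        current_step += train_steps_per_eval
        current_iteration += 1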
@@ -93,7 +93,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
           (epoch, epoch_run_time))


-def get_profiler_callback(model_dir, profile_steps, enable_tensorboard):
+def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
+                          steps_per_epoch):
   """Validate profile_steps flag value and return profiler callback."""
   profile_steps_error_message = (
       'profile_steps must be a comma separated pair of positive integers, '
@@ -114,26 +115,39 @@ def get_profiler_callback(model_dir, profile_steps, enable_tensorboard):
        'TensorBoard callback profiles the 2nd step (unless otherwise '
        'specified). Please make sure the steps profiled by the two callbacks '
        'do not overlap.')
-  return ProfilerCallback(model_dir, start_step, stop_step)
+  return ProfilerCallback(model_dir, start_step, stop_step, steps_per_epoch)


 class ProfilerCallback(tf.keras.callbacks.Callback):
   """Save profiles in specified step range to log directory."""

-  def __init__(self, log_dir, start_step, stop_step):
+  def __init__(self, log_dir, start_step, stop_step, steps_per_epoch):
     super(ProfilerCallback, self).__init__()
     self.log_dir = log_dir
     self.start_step = start_step
     self.stop_step = stop_step
+    self.start_epoch = start_step // steps_per_epoch
+    self.stop_epoch = stop_step // steps_per_epoch
+    self.start_step_in_epoch = start_step % steps_per_epoch
+    self.stop_step_in_epoch = stop_step % steps_per_epoch
+    self.should_start = False
+    self.should_stop = False
+
+  def on_epoch_begin(self, epoch, logs=None):
+    if epoch == self.start_epoch:
+      self.should_start = True
+    if epoch == self.stop_epoch:
+      self.should_stop = True

   def on_batch_begin(self, batch, logs=None):
-    if batch == self.start_step:
+    if batch == self.start_step_in_epoch and self.should_start:
+      self.should_start = False
       profiler.start()
       tf.compat.v1.logging.info('Profiler started at Step %s', self.start_step)

   def on_batch_end(self, batch, logs=None):
-    if batch == self.stop_step:
+    if batch == self.stop_step_in_epoch and self.should_stop:
+      self.should_stop = False
       results = profiler.stop()
       profiler.save(self.log_dir, results)
       tf.compat.v1.logging.info(
......
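Why the extra bookkeeping above is needed: Keras passes on_batch_begin/on_batch_end the batch index within the current epoch, while --profile_steps is given as global step numbers. Decomposing the global range into (epoch, step-in-epoch) pairs, together with the should_start/should_stop flags set in on_epoch_begin, lets the range cross an epoch boundary. A standalone illustration with made-up numbers:

    # With 100 steps per epoch, the global range 195..205 starts at batch 95
    # of epoch 1 and stops at batch 5 of epoch 2, so per-epoch batch indices
    # alone are not enough to locate the start and stop points.
    start_step, stop_step, steps_per_epoch = 195, 205, 100

    start_epoch, start_step_in_epoch = divmod(start_step, steps_per_epoch)
    stop_epoch, stop_step_in_epoch = divmod(stop_step, steps_per_epoch)

    print(start_epoch, start_step_in_epoch)  # 1 95
    print(stop_epoch, stop_step_in_epoch)    # 2 5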
@@ -38,7 +38,7 @@ LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
 def learning_rate_schedule(current_epoch,
                            current_batch,
-                           batches_per_epoch,
+                           steps_per_epoch,
                            batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.
@@ -48,14 +48,14 @@ def learning_rate_schedule(current_epoch,
   Args:
     current_epoch: integer, current epoch indexed from 0.
     current_batch: integer, current batch in the current epoch, indexed from 0.
-    batches_per_epoch: integer, number of steps in an epoch.
+    steps_per_epoch: integer, number of steps in an epoch.
     batch_size: integer, total batch sized.

   Returns:
     Adjusted learning rate.
   """
   initial_lr = BASE_LEARNING_RATE * batch_size / 256
-  epoch = current_epoch + float(current_batch) / batches_per_epoch
+  epoch = current_epoch + float(current_batch) / steps_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
     # Learning rate increases linearly per step.
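A quick worked example of the fractional-epoch computation above, with illustrative numbers (a CIFAR-10-sized epoch of 390 steps and batch size 128): halfway through epoch 1 the schedule sees epoch 1.5, and during warmup that fractional epoch drives the linear ramp of the learning rate.

    # Illustrative numbers only; mirrors the formula in learning_rate_schedule.
    BASE_LEARNING_RATE = 0.1
    batch_size, steps_per_epoch = 128, 390
    current_epoch, current_batch = 1, 195

    initial_lr = BASE_LEARNING_RATE * batch_size / 256              # 0.05
    epoch = current_epoch + float(current_batch) / steps_per_epoch  # 1.5
    # During warmup the returned rate grows linearly with this fractional epoch.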
@@ -79,10 +79,10 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
       output (float).
   """

-  def __init__(self, schedule, batch_size, num_images):
+  def __init__(self, schedule, batch_size, steps_per_epoch):
     super(LearningRateBatchScheduler, self).__init__()
     self.schedule = schedule
-    self.batches_per_epoch = num_images / batch_size
+    self.steps_per_epoch = steps_per_epoch
     self.batch_size = batch_size
     self.epochs = -1
     self.prev_lr = -1
@@ -96,7 +96,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     """Executes before step begins."""
     lr = self.schedule(self.epochs,
                        batch,
-                       self.batches_per_epoch,
+                       self.steps_per_epoch,
                        self.batch_size)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
@@ -120,13 +120,12 @@ class PiecewiseConstantDecayWithWarmup(
                        'length of multipliers')

     base_lr_batch_size = 256
-    num_batches_per_epoch = epoch_size // batch_size
+    steps_per_epoch = epoch_size // batch_size

     self.rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
-    self.step_boundaries = [float(num_batches_per_epoch) * x
-                            for x in boundaries]
+    self.step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
     self.lr_values = [self.rescaled_lr * m for m in multipliers]
-    self.warmup_steps = warmup_epochs * num_batches_per_epoch
+    self.warmup_steps = warmup_epochs * steps_per_epoch
     self.compute_lr_on_cpu = compute_lr_on_cpu
     self.name = name
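As a worked example of the constructor arithmetic above (values chosen for illustration, not repository defaults): with 50,000 training images and batch size 128 an epoch is 390 steps, so epoch-based boundaries and warmup lengths translate into step counts as follows.

    # Illustrative values: CIFAR-10-sized dataset, batch size 128.
    BASE_LEARNING_RATE = 0.1
    epoch_size, batch_size, base_lr_batch_size = 50000, 128, 256
    boundaries, multipliers, warmup_epochs = [91, 136, 182], [1.0, 0.1, 0.01], 5

    steps_per_epoch = epoch_size // batch_size                          # 390
    rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size  # 0.05
    step_boundaries = [float(steps_per_epoch) * x for x in boundaries]  # [35490.0, 53040.0, 70980.0]
    lr_values = [rescaled_lr * m for m in multipliers]                  # [0.05, 0.005, 0.0005]
    warmup_steps = warmup_epochs * steps_per_epoch                      # 1950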
@@ -208,7 +207,7 @@ def get_optimizer(learning_rate=0.1):
 # TODO(hongkuny,haoyuzhang): make cifar model use_tensor_lr to clean up code.
-def get_callbacks(learning_rate_schedule_fn=None, num_images=None):
+def get_callbacks(steps_per_epoch, learning_rate_schedule_fn=None):
   """Returns common callbacks."""
   time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
   callbacks = [time_callback]
@@ -217,7 +216,7 @@ def get_callbacks(learning_rate_schedule_fn=None, num_images=None):
     lr_callback = LearningRateBatchScheduler(
         learning_rate_schedule_fn,
         batch_size=FLAGS.batch_size,
-        num_images=num_images)
+        steps_per_epoch=steps_per_epoch)
     callbacks.append(lr_callback)

   if FLAGS.enable_tensorboard:
@@ -229,7 +228,8 @@ def get_callbacks(learning_rate_schedule_fn=None, num_images=None):
     profiler_callback = keras_utils.get_profiler_callback(
         FLAGS.model_dir,
         FLAGS.profile_steps,
-        FLAGS.enable_tensorboard)
+        FLAGS.enable_tensorboard,
+        steps_per_epoch)
     callbacks.append(profiler_callback)
   return callbacks
@@ -332,7 +332,7 @@ def define_keras_flags(dynamic_loss_scale=True):
      'ignored if train_epochs is set to be larger than 1. ')
   flags.DEFINE_string(
       name='profile_steps', default=None,
-      help='Save profiling data to model dir at given range of steps. The '
+      help='Save profiling data to model dir at given range of global steps. The '
      'value must be a comma separated pair of positive integers, specifying '
      'the first and last step to profile. For example, "--profile_steps=2,4" '
      'triggers the profiler to process 3 steps, starting from the 2nd step. '
......
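The only wording change in the flag definition is that profile_steps is now described as a range of global steps, counted from the start of training rather than within an epoch; it is still a comma-separated pair of positive integers. A minimal sketch of that interpretation (an assumption for illustration, not the library's actual parsing code):

    # '--profile_steps=400,420' means global steps 400 through 420.
    profile_steps = '400,420'
    start_step, stop_step = (int(i) for i in profile_steps.split(','))
    assert 0 < start_step <= stop_step, 'expects an increasing pair of positive integers'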
@@ -171,14 +171,15 @@ def run(flags_obj):
           if flags_obj.report_accuracy_metrics else None),
       run_eagerly=flags_obj.run_eagerly)

-  callbacks = common.get_callbacks(
-      learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])
-  train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
+  steps_per_epoch = (
+      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
   train_epochs = flags_obj.train_epochs
-  if flags_obj.train_steps:
-    train_steps = min(flags_obj.train_steps, train_steps)
+  callbacks = common.get_callbacks(steps_per_epoch, learning_rate_schedule)
+
+  # if mutliple epochs, ignore the train_steps flag.
+  if train_epochs <= 1 and flags_obj.train_steps:
+    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
     train_epochs = 1

   num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
@@ -201,7 +202,7 @@ def run(flags_obj):
   history = model.fit(train_input_dataset,
                       epochs=train_epochs,
-                      steps_per_epoch=train_steps,
+                      steps_per_epoch=steps_per_epoch,
                       callbacks=callbacks,
                       validation_steps=num_eval_steps,
                       validation_data=validation_data,
@@ -225,7 +226,6 @@ def define_cifar_flags():
   flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
                           model_dir='/tmp/cifar10_model',
-                          train_epochs=182,
                           epochs_between_evals=10,
                           batch_size=128)
......
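The CIFAR-10 runner now computes a single steps_per_epoch up front, hands it to the callbacks, and only honors the train_steps flag for single-epoch runs. A simplified standalone sketch of that bookkeeping (flag values are hypothetical):

    num_train_images = 50000   # cifar_preprocessing.NUM_IMAGES['train']
    batch_size = 128
    train_epochs = 1
    train_steps = 100          # stand-in for flags_obj.train_steps

    steps_per_epoch = num_train_images // batch_size       # 390
    # train_steps only caps a single-epoch run; with multiple epochs it is ignored.
    if train_epochs <= 1 and train_steps:
      steps_per_epoch = min(train_steps, steps_per_epoch)   # 100
      train_epochs = 1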
@@ -187,19 +187,20 @@ def run(flags_obj):
           if flags_obj.report_accuracy_metrics else None),
       run_eagerly=flags_obj.run_eagerly)

-  callbacks = common.get_callbacks(
-      common.learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
+  steps_per_epoch = (
+      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
+  train_epochs = flags_obj.train_epochs
+  callbacks = common.get_callbacks(steps_per_epoch,
+                                   common.learning_rate_schedule)

   if flags_obj.enable_checkpoint_and_export:
     ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
     callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                                         save_weights_only=True))

-  train_steps = (
-      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
-  train_epochs = flags_obj.train_epochs
-
   # if mutliple epochs, ignore the train_steps flag.
   if train_epochs <= 1 and flags_obj.train_steps:
-    train_steps = min(flags_obj.train_steps, train_steps)
+    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
     train_epochs = 1

   num_eval_steps = (
@@ -225,7 +226,7 @@ def run(flags_obj):
   history = model.fit(train_input_dataset,
                       epochs=train_epochs,
-                      steps_per_epoch=train_steps,
+                      steps_per_epoch=steps_per_epoch,
                       callbacks=callbacks,
                       validation_steps=num_eval_steps,
                       validation_data=validation_data,
......
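The ImageNet runner gets the same restructuring, and at ImageNet scale the profiler change matters in practice: with roughly 1.28 million training images an epoch is thousands of steps, so a profiling range given in global steps can easily straddle an epoch boundary, which the updated ProfilerCallback handles. A small worked example (batch size 256 chosen for illustration):

    # 1,281,167 training images at batch size 256 give 5004 steps per epoch,
    # so --profile_steps=5000,5010 spans the boundary between epochs 0 and 1.
    steps_per_epoch = 1281167 // 256                                    # 5004
    start_step, stop_step = 5000, 5010
    print(start_step // steps_per_epoch, start_step % steps_per_epoch)  # 0 5000
    print(stop_step // steps_per_epoch, stop_step % steps_per_epoch)    # 1 6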