"docker/vscode:/vscode.git/clone" did not exist on "6dcdaf59ce813bca5a3b8f6d333741d9bd331c30"
Commit b60dc237 authored by Will Cromar's avatar Will Cromar Committed by A. Unique TensorFlower
Browse files

Write examples/second and steps/second summaries in TimeHistory callback.

PiperOrigin-RevId: 296507807
parent 706a0bd9
...@@ -298,13 +298,16 @@ class EpochHelper(object): ...@@ -298,13 +298,16 @@ class EpochHelper(object):
self._epoch_steps = epoch_steps self._epoch_steps = epoch_steps
self._global_step = global_step self._global_step = global_step
self._current_epoch = None self._current_epoch = None
self._epoch_start_step = None
self._in_epoch = False self._in_epoch = False
def epoch_begin(self):
  """Returns whether a new epoch should begin.

  Side effects when a new epoch starts: records the global step at which
  the epoch began (`_epoch_start_step`), computes the 0-indexed epoch
  number via integer division, and marks the helper as inside an epoch.

  Returns:
    True if a new epoch begins, False if we are already inside one.
  """
  if self._in_epoch:
    return False
  current_step = self._global_step.numpy()
  self._epoch_start_step = current_step
  # Integer division: epoch index must be a whole number, not a float.
  self._current_epoch = current_step // self._epoch_steps
  self._in_epoch = True
  return True
# NOTE(review): the `def` line of this method was cut by the diff hunk;
# signature reconstructed from the method body — verify against upstream.
def epoch_end(self):
  """Returns whether the current epoch should end.

  Raises:
    ValueError: if called while not inside an epoch (i.e. without a
      matching `epoch_begin`).
  """
  if not self._in_epoch:
    raise ValueError("`epoch_end` can only be called inside an epoch")
  current_step = self._global_step.numpy()
  # Integer division keeps the epoch index a whole number, matching the
  # computation in `epoch_begin`.
  epoch = current_step // self._epoch_steps
  if epoch > self._current_epoch:
    self._in_epoch = False
    return True
  return False
@property
def batch_index(self):
  """Index of the next batch within the current epoch.

  Computed as the number of global steps taken since the step recorded
  at `epoch_begin` (`_epoch_start_step`).
  """
  return self._global_step.numpy() - self._epoch_start_step
@property
def current_epoch(self):
  """The 0-indexed epoch number set by the most recent `epoch_begin`."""
  return self._current_epoch
...@@ -44,17 +44,28 @@ class BatchTimestamp(object): ...@@ -44,17 +44,28 @@ class BatchTimestamp(object):
class TimeHistory(tf.keras.callbacks.Callback): class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models.""" """Callback for Keras models."""
def __init__(self, batch_size, log_steps, logdir=None):
  """Callback for logging performance.

  Args:
    batch_size: Total batch size.
    log_steps: Interval of steps between logging of batch level stats.
    logdir: Optional directory to write TensorBoard summaries.
  """
  # TODO(wcromar): remove this parameter and rely on `logs` parameter of
  # on_train_batch_end()
  self.batch_size = batch_size
  super(TimeHistory, self).__init__()
  self.log_steps = log_steps
  # Global step at which stats were last logged; used to measure the
  # interval length in on_batch_end.
  self.last_log_step = 0
  # Steps completed in previous (finished) epochs vs. the current one;
  # their sum is the `global_steps` property.
  self.steps_before_epoch = 0
  self.steps_in_epoch = 0
  # Wall-clock start of the current logging interval; None means "not
  # started yet" and is reset after each logged interval.
  self.start_time = None

  if logdir:
    self.summary_writer = tf.summary.create_file_writer(logdir)
  else:
    self.summary_writer = None

  # Logs start of step 1 then end of each step based on log_steps interval.
  self.timestamp_log = []

  # Records the time each epoch takes to run from start to finish of epoch.
  # NOTE(review): a line or two between the two diff hunks (L65) is not
  # visible here — verify this tail against upstream.
  self.epoch_runtime_log = []
@property
def global_steps(self):
  """The current 1-indexed global step."""
  return self.steps_before_epoch + self.steps_in_epoch
@property
def average_steps_per_second(self):
  """The average training steps per second across all epochs.

  Note: raises ZeroDivisionError if no epoch has completed yet; callers
  guard on `epoch_runtime_log` being non-empty.
  """
  return self.global_steps / sum(self.epoch_runtime_log)
@property
def average_examples_per_second(self):
  """The average number of training examples per second across all epochs."""
  return self.average_steps_per_second * self.batch_size
def on_train_end(self, logs=None):
  """Records the wall-clock training end time and flushes summaries."""
  self.train_finish_time = time.time()

  # Make sure buffered TensorBoard scalars reach disk before exit.
  if self.summary_writer:
    self.summary_writer.flush()
def on_epoch_begin(self, epoch, logs=None):
  """Records the wall-clock time at which the epoch starts."""
  self.epoch_start = time.time()
def on_batch_begin(self, batch, logs=None):
  """Starts the interval timer and records the first step's timestamp."""
  # Lazily (re)start the interval timer — it is reset to None after each
  # logged interval in on_batch_end.
  if not self.start_time:
    self.start_time = time.time()

  # Record the timestamp of the first global step
  if not self.timestamp_log:
    self.timestamp_log.append(BatchTimestamp(self.global_steps,
                                             self.start_time))
def on_batch_end(self, batch, logs=None):
  """Records elapse time of the batch and calculates examples per second."""
  # `batch` is the 0-indexed batch within the current epoch, so the number
  # of steps taken this epoch is batch + 1.
  self.steps_in_epoch = batch + 1
  steps_since_last_log = self.global_steps - self.last_log_step
  if steps_since_last_log >= self.log_steps:
    now = time.time()
    elapsed_time = now - self.start_time
    steps_per_second = steps_since_last_log / elapsed_time
    examples_per_second = steps_per_second * self.batch_size

    self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
    logging.info(
        "TimeHistory: %.2f examples/second between steps %d and %d",
        examples_per_second, self.last_log_step, self.global_steps)

    # Mirror the log line into TensorBoard scalars when a writer exists.
    if self.summary_writer:
      with self.summary_writer.as_default():
        tf.summary.scalar('global_step/sec', steps_per_second,
                          self.global_steps)
        tf.summary.scalar('examples/sec', examples_per_second,
                          self.global_steps)

    self.last_log_step = self.global_steps
    # Reset so on_batch_begin restarts the timer for the next interval.
    self.start_time = None
def on_epoch_end(self, epoch, logs=None):
  """Records the epoch runtime and rolls step counters into the total."""
  epoch_run_time = time.time() - self.epoch_start
  self.epoch_runtime_log.append(epoch_run_time)

  # Fold the finished epoch's steps into the completed-epochs counter so
  # the `global_steps` property stays monotonic across epochs.
  self.steps_before_epoch += self.steps_in_epoch
  self.steps_in_epoch = 0
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard, def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
......
...@@ -188,7 +188,10 @@ def get_callbacks( ...@@ -188,7 +188,10 @@ def get_callbacks(
enable_checkpoint_and_export=False, enable_checkpoint_and_export=False,
model_dir=None): model_dir=None):
"""Returns common callbacks.""" """Returns common callbacks."""
time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps) time_callback = keras_utils.TimeHistory(
FLAGS.batch_size,
FLAGS.log_steps,
logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
callbacks = [time_callback] callbacks = [time_callback]
if not FLAGS.use_tensor_lr and learning_rate_schedule_fn: if not FLAGS.use_tensor_lr and learning_rate_schedule_fn:
...@@ -265,11 +268,9 @@ def build_stats(history, eval_output, callbacks): ...@@ -265,11 +268,9 @@ def build_stats(history, eval_output, callbacks):
timestamp_log = callback.timestamp_log timestamp_log = callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = callback.train_finish_time stats['train_finish_time'] = callback.train_finish_time
if len(timestamp_log) > 1: if callback.epoch_runtime_log:
stats['avg_exp_per_second'] = ( stats['avg_exp_per_second'] = callback.average_examples_per_second
callback.batch_size * callback.log_steps *
(len(callback.timestamp_log)-1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
return stats return stats
......
...@@ -64,15 +64,8 @@ def build_stats(runnable, time_callback): ...@@ -64,15 +64,8 @@ def build_stats(runnable, time_callback):
timestamp_log = time_callback.timestamp_log timestamp_log = time_callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = time_callback.train_finish_time stats['train_finish_time'] = time_callback.train_finish_time
if len(timestamp_log) > 1: if time_callback.epoch_runtime_log:
stats['avg_exp_per_second'] = ( stats['avg_exp_per_second'] = time_callback.average_examples_per_second
time_callback.batch_size * time_callback.log_steps *
(len(time_callback.timestamp_log) - 1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
avg_exp_per_second = tf.reduce_mean(
runnable.examples_per_second_history).numpy(),
stats['avg_exp_per_second'] = avg_exp_per_second
return stats return stats
...@@ -154,8 +147,10 @@ def run(flags_obj): ...@@ -154,8 +147,10 @@ def run(flags_obj):
'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps, 'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
train_epochs * per_epoch_steps, eval_steps) train_epochs * per_epoch_steps, eval_steps)
time_callback = keras_utils.TimeHistory(flags_obj.batch_size, time_callback = keras_utils.TimeHistory(
flags_obj.log_steps) flags_obj.batch_size,
flags_obj.log_steps,
logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
with distribution_utils.get_strategy_scope(strategy): with distribution_utils.get_strategy_scope(strategy):
runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback, runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
per_epoch_steps) per_epoch_steps)
......
...@@ -114,7 +114,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -114,7 +114,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
# Handling epochs. # Handling epochs.
self.epoch_steps = epoch_steps self.epoch_steps = epoch_steps
self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step) self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
self.examples_per_second_history = []
def build_train_dataset(self): def build_train_dataset(self):
"""See base class.""" """See base class."""
...@@ -147,8 +146,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -147,8 +146,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
self.train_loss.reset_states() self.train_loss.reset_states()
self.train_accuracy.reset_states() self.train_accuracy.reset_states()
self.time_callback.on_batch_begin(self.global_step)
self._epoch_begin() self._epoch_begin()
self.time_callback.on_batch_begin(self.epoch_helper.batch_index)
def train_step(self, iterator): def train_step(self, iterator):
"""See base class.""" """See base class."""
...@@ -194,12 +193,13 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -194,12 +193,13 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
def train_loop_end(self):
  """See base class.

  Snapshots the train metrics, then notifies the time callback of the
  last finished batch (`batch_index - 1`, since `batch_index` points at
  the *next* batch) and closes out the epoch if needed.
  """
  metrics = {
      'train_loss': self.train_loss.result(),
      'train_accuracy': self.train_accuracy.result(),
  }
  self.time_callback.on_batch_end(self.epoch_helper.batch_index - 1)
  self._epoch_end()
  return metrics
def eval_begin(self): def eval_begin(self):
"""See base class.""" """See base class."""
...@@ -234,10 +234,3 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -234,10 +234,3 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
def _epoch_end(self):
  """Forwards the epoch boundary to the time callback when an epoch ends."""
  if self.epoch_helper.epoch_end():
    self.time_callback.on_epoch_end(self.epoch_helper.current_epoch)
epoch_time = self.time_callback.epoch_runtime_log[-1]
steps_per_second = self.epoch_steps / epoch_time
examples_per_second = steps_per_second * self.flags_obj.batch_size
self.examples_per_second_history.append(examples_per_second)
tf.summary.scalar('global_step/sec', steps_per_second)
tf.summary.scalar('examples/sec', examples_per_second)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment