Unverified Commit ae655521 authored by Shining Sun's avatar Shining Sun Committed by GitHub
Browse files

Add timestamp history for each batch in training (#6024)

* Add timestamp history for each batch in training

* Resolve github comments

* Change the batch start recording logic

* Lint fix
parent c7f29b2a
...@@ -186,13 +186,13 @@ def run(flags_obj): ...@@ -186,13 +186,13 @@ def run(flags_obj):
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps, steps=num_eval_steps,
verbose=1) verbose=1)
stats = keras_common.build_stats(history, eval_output) stats = keras_common.build_stats(history, eval_output, time_callback)
return stats return stats
def main(_): def main(_):
with logger.benchmark_context(flags.FLAGS): with logger.benchmark_context(flags.FLAGS):
run(flags.FLAGS) return run(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -28,16 +28,23 @@ import tensorflow as tf ...@@ -28,16 +28,23 @@ import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2) gradient_descent_v2)
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version. BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
TRAIN_TOP_1 = 'training_accuracy_top_1' TRAIN_TOP_1 = 'training_accuracy_top_1'
class BatchTimestamp(object):
  """A structure to store a batch index and its wall-clock timestamp.

  Attributes:
    batch_index: int, index of the batch within the current training run.
    timestamp: float, wall-clock time in seconds (value of `time.time()`
      taken by the caller).
  """

  def __init__(self, batch_index, timestamp):
    self.batch_index = batch_index
    self.timestamp = timestamp

  def __repr__(self):
    # Readable form for when timestamp lists are logged or inspected in a
    # debugger; without this, lists of BatchTimestamp print as opaque objects.
    return "BatchTimestamp<batch_index: {}, timestamp: {}>".format(
        self.batch_index, self.timestamp)
class TimeHistory(tf.keras.callbacks.Callback): class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models.""" """Callback for Keras models."""
def __init__(self, batch_size): def __init__(self, batch_size, log_steps):
"""Callback for logging performance (# image/second). """Callback for logging performance (# image/second).
Args: Args:
...@@ -46,23 +53,34 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -46,23 +53,34 @@ class TimeHistory(tf.keras.callbacks.Callback):
""" """
self._batch_size = batch_size self._batch_size = batch_size
super(TimeHistory, self).__init__() super(TimeHistory, self).__init__()
self.log_steps = 100 self.log_steps = log_steps
# has stats for all batches
self.batch_start_timestamps = []
# only has stats for batch_index % log_steps == 0 (excluding 0)
self.batch_end_timestamps = []
def on_train_begin(self, logs=None):
  """Arms per-batch timing so the first batch records a start timestamp."""
  self.record_batch = True
def on_train_end(self, logs=None):
  """Records the wall-clock time at which training finished."""
  # Surfaced later via build_stats as stats['train_finish_time'].
  self.train_finish_time = time.time()
def on_batch_begin(self, batch, logs=None):
  """Records the start of a logging interval.

  Only fires when `record_batch` is armed (at train start and right after
  each logging boundary), so one start timestamp is kept per interval.
  """
  if not self.record_batch:
    return
  now = time.time()
  self.start_time = now
  self.record_batch = False
  self.batch_start_timestamps.append(BatchTimestamp(batch, now))
def on_batch_end(self, batch, logs=None):
  """Logs throughput and stores an end timestamp every `log_steps` batches."""
  if batch % self.log_steps != 0:
    return
  now = time.time()
  elapsed_time = now - self.start_time
  examples_per_second = (self._batch_size * self.log_steps) / elapsed_time
  # Re-arm so the next on_batch_begin starts a fresh timing interval.
  self.record_batch = True
  # Batch 0 is skipped: its interval covers fewer than log_steps batches,
  # so the throughput number would be misleading.
  if batch != 0:
    self.batch_end_timestamps.append(BatchTimestamp(batch, now))
    tf.logging.info("BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
                    "'images_per_second': %f}" %
                    (batch, elapsed_time, examples_per_second))
...@@ -115,7 +133,7 @@ def get_optimizer(): ...@@ -115,7 +133,7 @@ def get_optimizer():
def get_callbacks(learning_rate_schedule_fn, num_images): def get_callbacks(learning_rate_schedule_fn, num_images):
"""Returns common callbacks.""" """Returns common callbacks."""
time_callback = TimeHistory(FLAGS.batch_size) time_callback = TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
tensorboard_callback = tf.keras.callbacks.TensorBoard( tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=FLAGS.model_dir) log_dir=FLAGS.model_dir)
...@@ -128,7 +146,7 @@ def get_callbacks(learning_rate_schedule_fn, num_images): ...@@ -128,7 +146,7 @@ def get_callbacks(learning_rate_schedule_fn, num_images):
return time_callback, tensorboard_callback, lr_callback return time_callback, tensorboard_callback, lr_callback
def build_stats(history, eval_output): def build_stats(history, eval_output, time_callback):
"""Normalizes and returns dictionary of stats. """Normalizes and returns dictionary of stats.
Args: Args:
...@@ -144,6 +162,7 @@ def build_stats(history, eval_output): ...@@ -144,6 +162,7 @@ def build_stats(history, eval_output):
if eval_output: if eval_output:
stats['accuracy_top_1'] = eval_output[1].item() stats['accuracy_top_1'] = eval_output[1].item()
stats['eval_loss'] = eval_output[0].item() stats['eval_loss'] = eval_output[0].item()
if history and history.history: if history and history.history:
train_hist = history.history train_hist = history.history
# Gets final loss from training. # Gets final loss from training.
...@@ -154,6 +173,11 @@ def build_stats(history, eval_output): ...@@ -154,6 +173,11 @@ def build_stats(history, eval_output):
elif 'sparse_categorical_accuracy' in train_hist: elif 'sparse_categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item() stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()
if time_callback:
stats['batch_start_timestamps'] = time_callback.batch_start_timestamps
stats['batch_end_timestamps'] = time_callback.batch_end_timestamps
stats['train_finish_time'] = time_callback.train_finish_time
return stats return stats
...@@ -165,6 +189,11 @@ def define_keras_flags(): ...@@ -165,6 +189,11 @@ def define_keras_flags():
help='The number of steps to run for training. If it is larger than ' help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # batches per epoch. When this flag is ' '# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.') 'set, only one epoch is going to run for training.')
# One flag controls both behaviors of TimeHistory: the logging frequency of
# examples/sec and how often a batch-end timestamp is recorded.
flags.DEFINE_integer(
name='log_steps', default=100,
help='For every log_steps, we log the timing information such as '
'examples per second. Besides, for every log_steps, we store the '
'timestamp of a batch end.')
def get_synth_input_fn(height, width, num_channels, num_classes, def get_synth_input_fn(height, width, num_channels, num_classes,
......
...@@ -36,7 +36,12 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -36,7 +36,12 @@ class KerasCommonTests(tf.test.TestCase):
history = self._build_history(1.145, cat_accuracy=.99988) history = self._build_history(1.145, cat_accuracy=.99988)
eval_output = self._build_eval_output(.56432111, 5.990) eval_output = self._build_eval_output(.56432111, 5.990)
stats = keras_common.build_stats(history, eval_output) th = keras_common.TimeHistory(128, 100)
th.batch_start_timestamps = [1, 2, 3]
th.batch_end_timestamps = [4, 5, 6]
th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, th)
self.assertEqual(1.145, stats['loss']) self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1']) self.assertEqual(.99988, stats['training_accuracy_top_1'])
...@@ -44,11 +49,15 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -44,11 +49,15 @@ class KerasCommonTests(tf.test.TestCase):
self.assertEqual(.56432111, stats['accuracy_top_1']) self.assertEqual(.56432111, stats['accuracy_top_1'])
self.assertEqual(5.990, stats['eval_loss']) self.assertEqual(5.990, stats['eval_loss'])
self.assertItemsEqual([1, 2, 3], stats['batch_start_timestamps'])
self.assertItemsEqual([4, 5, 6], stats['batch_end_timestamps'])
self.assertEqual(12345, stats['train_finish_time'])
def test_build_stats_sparse(self): def test_build_stats_sparse(self):
history = self._build_history(1.145, cat_accuracy_sparse=.99988) history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844) eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output) stats = keras_common.build_stats(history, eval_output, None)
self.assertEqual(1.145, stats['loss']) self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1']) self.assertEqual(.99988, stats['training_accuracy_top_1'])
...@@ -56,6 +65,36 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -56,6 +65,36 @@ class KerasCommonTests(tf.test.TestCase):
self.assertEqual(.928, stats['accuracy_top_1']) self.assertEqual(.928, stats['accuracy_top_1'])
self.assertEqual(1.9844, stats['eval_loss']) self.assertEqual(1.9844, stats['eval_loss'])
def test_time_history(self):
  """Drives 7 batches through TimeHistory and checks recorded timestamps."""
  th = keras_common.TimeHistory(batch_size=128, log_steps=3)

  th.on_train_begin()
  for batch_index in range(7):
    th.on_batch_begin(batch_index)
    th.on_batch_end(batch_index)
  th.on_train_end()

  # Start timestamps: batch 0 plus the batch right after each logging
  # boundary (batches 1 and 4 for log_steps=3).
  self.assertEqual(3, len(th.batch_start_timestamps))
  self.assertEqual(
      [0, 1, 4], [ts.batch_index for ts in th.batch_start_timestamps])

  # End timestamps: multiples of log_steps, excluding batch 0.
  self.assertEqual(2, len(th.batch_end_timestamps))
  self.assertEqual(
      [3, 6], [ts.batch_index for ts in th.batch_end_timestamps])
def _build_history(self, loss, cat_accuracy=None, def _build_history(self, loss, cat_accuracy=None,
cat_accuracy_sparse=None): cat_accuracy_sparse=None):
history_p = Mock() history_p = Mock()
......
...@@ -173,19 +173,19 @@ def run(flags_obj): ...@@ -173,19 +173,19 @@ def run(flags_obj):
validation_steps=num_eval_steps, validation_steps=num_eval_steps,
validation_data=validation_data, validation_data=validation_data,
verbose=1) verbose=1)
eval_output = None eval_output = None
if not flags_obj.skip_eval: if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps, steps=num_eval_steps,
verbose=1) verbose=1)
stats = keras_common.build_stats(history, eval_output, time_callback)
stats = keras_common.build_stats(history, eval_output)
return stats return stats
def main(_): def main(_):
with logger.benchmark_context(flags.FLAGS): with logger.benchmark_context(flags.FLAGS):
run(flags.FLAGS) return run(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment