Unverified Commit ae655521 authored by Shining Sun's avatar Shining Sun Committed by GitHub
Browse files

Add timestamp history for each batch in training (#6024)

* Add timestamp history for each batch in training

* Resolve github comments

* Change the batch start recording logic

* Lint fix
parent c7f29b2a
......@@ -186,13 +186,13 @@ def run(flags_obj):
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=1)
stats = keras_common.build_stats(history, eval_output)
stats = keras_common.build_stats(history, eval_output, time_callback)
return stats
def main(_):
with logger.benchmark_context(flags.FLAGS):
run(flags.FLAGS)
return run(flags.FLAGS)
if __name__ == '__main__':
......
......@@ -28,16 +28,23 @@ import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
TRAIN_TOP_1 = 'training_accuracy_top_1'
class BatchTimestamp(object):
  """A structure to store batch time stamp."""

  def __init__(self, batch_index, timestamp):
    """Records the wall-clock time of a given batch.

    Args:
      batch_index: int, index of the batch within the training run.
      timestamp: float, wall-clock time in seconds (as from `time.time()`).
    """
    self.batch_index = batch_index
    self.timestamp = timestamp

  def __repr__(self):
    # Instances end up inside the logged/serialized benchmark stats dict,
    # so a readable repr beats the default `<object at 0x...>`.
    return "BatchTimestamp(batch_index={}, timestamp={})".format(
        self.batch_index, self.timestamp)
class TimeHistory(tf.keras.callbacks.Callback):
  """Keras callback that records per-batch timing and throughput stats."""

  def __init__(self, batch_size, log_steps):
    """Callback for logging performance (# image/second).

    Args:
      batch_size: Total batch size.
      log_steps: Interval of steps between logging of batch level stats.
    """
    self._batch_size = batch_size
    super(TimeHistory, self).__init__()
    self.log_steps = log_steps

    # Start timestamps: one entry per logging interval (the first batch
    # after each interval boundary, including batch 0).
    self.batch_start_timestamps = []
    # End timestamps: only for batch_index % log_steps == 0, excluding 0.
    self.batch_end_timestamps = []

  def on_train_begin(self, logs=None):
    # Arm recording so the very first batch gets a start timestamp.
    self.record_batch = True

  def on_train_end(self, logs=None):
    self.train_finish_time = time.time()

  def on_batch_begin(self, batch, logs=None):
    # Only the first batch of each logging interval is recorded; the flag
    # is re-armed in on_batch_end when an interval completes.
    if self.record_batch:
      timestamp = time.time()
      self.start_time = timestamp
      self.record_batch = False
      self.batch_start_timestamps.append(BatchTimestamp(batch, timestamp))

  def on_batch_end(self, batch, logs=None):
    if batch % self.log_steps == 0:
      timestamp = time.time()
      elapsed_time = timestamp - self.start_time
      examples_per_second = (self._batch_size * self.log_steps) / elapsed_time
      self.record_batch = True
      # Batch 0 closes no full interval, so its end timestamp is skipped.
      if batch != 0:
        self.batch_end_timestamps.append(BatchTimestamp(batch, timestamp))
      tf.logging.info("BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
                      "'images_per_second': %f}" %
                      (batch, elapsed_time, examples_per_second))
......@@ -115,7 +133,7 @@ def get_optimizer():
def get_callbacks(learning_rate_schedule_fn, num_images):
"""Returns common callbacks."""
time_callback = TimeHistory(FLAGS.batch_size)
time_callback = TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=FLAGS.model_dir)
......@@ -128,7 +146,7 @@ def get_callbacks(learning_rate_schedule_fn, num_images):
return time_callback, tensorboard_callback, lr_callback
def build_stats(history, eval_output):
def build_stats(history, eval_output, time_callback):
"""Normalizes and returns dictionary of stats.
Args:
......@@ -144,6 +162,7 @@ def build_stats(history, eval_output):
if eval_output:
stats['accuracy_top_1'] = eval_output[1].item()
stats['eval_loss'] = eval_output[0].item()
if history and history.history:
train_hist = history.history
# Gets final loss from training.
......@@ -154,6 +173,11 @@ def build_stats(history, eval_output):
elif 'sparse_categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()
if time_callback:
stats['batch_start_timestamps'] = time_callback.batch_start_timestamps
stats['batch_end_timestamps'] = time_callback.batch_end_timestamps
stats['train_finish_time'] = time_callback.train_finish_time
return stats
......@@ -165,6 +189,11 @@ def define_keras_flags():
help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.')
flags.DEFINE_integer(
name='log_steps', default=100,
help='For every log_steps, we log the timing information such as '
'examples per second. Besides, for every log_steps, we store the '
'timestamp of a batch end.')
def get_synth_input_fn(height, width, num_channels, num_classes,
......
......@@ -36,7 +36,12 @@ class KerasCommonTests(tf.test.TestCase):
history = self._build_history(1.145, cat_accuracy=.99988)
eval_output = self._build_eval_output(.56432111, 5.990)
stats = keras_common.build_stats(history, eval_output)
th = keras_common.TimeHistory(128, 100)
th.batch_start_timestamps = [1, 2, 3]
th.batch_end_timestamps = [4, 5, 6]
th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, th)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
......@@ -44,11 +49,15 @@ class KerasCommonTests(tf.test.TestCase):
self.assertEqual(.56432111, stats['accuracy_top_1'])
self.assertEqual(5.990, stats['eval_loss'])
self.assertItemsEqual([1, 2, 3], stats['batch_start_timestamps'])
self.assertItemsEqual([4, 5, 6], stats['batch_end_timestamps'])
self.assertEqual(12345, stats['train_finish_time'])
def test_build_stats_sparse(self):
history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output)
stats = keras_common.build_stats(history, eval_output, None)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
......@@ -56,6 +65,36 @@ class KerasCommonTests(tf.test.TestCase):
self.assertEqual(.928, stats['accuracy_top_1'])
self.assertEqual(1.9844, stats['eval_loss'])
def test_time_history(self):
  """Runs 7 batches with log_steps=3 and checks recorded timestamps."""
  th = keras_common.TimeHistory(batch_size=128, log_steps=3)

  th.on_train_begin()
  # Simulate seven consecutive training batches (indices 0..6).
  for batch_index in range(7):
    th.on_batch_begin(batch_index)
    th.on_batch_end(batch_index)
  th.on_train_end()

  start_indices = [ts.batch_index for ts in th.batch_start_timestamps]
  end_indices = [ts.batch_index for ts in th.batch_end_timestamps]

  # Starts are recorded for the first batch of each logging interval;
  # ends only for multiples of log_steps, excluding batch 0.
  self.assertEqual(3, len(th.batch_start_timestamps))
  self.assertEqual(2, len(th.batch_end_timestamps))
  self.assertEqual([0, 1, 4], start_indices)
  self.assertEqual([3, 6], end_indices)
def _build_history(self, loss, cat_accuracy=None,
cat_accuracy_sparse=None):
history_p = Mock()
......
......@@ -169,23 +169,23 @@ def run(flags_obj):
time_callback,
lr_callback,
tensorboard_callback
],
],
validation_steps=num_eval_steps,
validation_data=validation_data,
verbose=1)
eval_output = None
if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=1)
stats = keras_common.build_stats(history, eval_output)
stats = keras_common.build_stats(history, eval_output, time_callback)
return stats
def main(_):
with logger.benchmark_context(flags.FLAGS):
run(flags.FLAGS)
return run(flags.FLAGS)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment