Unverified Commit 3f94db4e authored by Haoyu Zhang, committed by GitHub

Add profiler callback for Keras models (#6528)

* Add profiler callback for Keras models

* Update build stats to identify time callback by type

* Add warning message when both TensorBoard and profiler callbacks are used
parent 7467ccde
@@ -161,7 +161,7 @@ def run(flags_obj):
                 optimizer=optimizer,
                 metrics=['categorical_accuracy'])

-  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
+  callbacks = keras_common.get_callbacks(
       learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

   train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
@@ -180,10 +180,6 @@ def run(flags_obj):
     num_eval_steps = None
     validation_data = None

-  callbacks = [time_callback, lr_callback]
-  if flags_obj.enable_tensorboard:
-    callbacks.append(tensorboard_callback)
-
   history = model.fit(train_input_dataset,
                       epochs=train_epochs,
                       steps_per_epoch=train_steps,
@@ -197,7 +193,7 @@ def run(flags_obj):
     eval_output = model.evaluate(eval_input_dataset,
                                  steps=num_eval_steps,
                                  verbose=2)
-  stats = keras_common.build_stats(history, eval_output, time_callback)
+  stats = keras_common.build_stats(history, eval_output, callbacks)
   return stats
......
@@ -30,6 +30,7 @@ import tensorflow as tf
 from official.utils.misc import keras_utils
 # pylint: disable=ungrouped-imports
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.eager import profiler
 from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
                                                    gradient_descent_v2)
@@ -78,6 +79,29 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
           'change learning rate to %s.', self.epochs, batch, lr)


+class ProfilerCallback(tf.keras.callbacks.Callback):
+  """Save profiles in specified step range to log directory."""
+
+  def __init__(self, log_dir, start_step, stop_step):
+    super(ProfilerCallback, self).__init__()
+    self.log_dir = log_dir
+    self.start_step = start_step
+    self.stop_step = stop_step
+
+  def on_batch_begin(self, batch, logs=None):
+    if batch == self.start_step:
+      profiler.start()
+      tf.compat.v1.logging.info('Profiler started at Step %s', self.start_step)
+
+  def on_batch_end(self, batch, logs=None):
+    if batch == self.stop_step:
+      results = profiler.stop()
+      profiler.save(self.log_dir, results)
+      tf.compat.v1.logging.info(
+          'Profiler saved profiles for steps between %s and %s to %s',
+          self.start_step, self.stop_step, self.log_dir)
+
+
 def get_config_proto_v1():
   """Return config proto according to flag settings, or None to use default."""
   config = None
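For context, a minimal standalone sketch of attaching this callback to model.fit. The toy model, synthetic data, and /tmp/profile log directory are assumptions for illustration only; it presumes a TF build where tensorflow.python.eager.profiler is available and eager execution is enabled.

import numpy as np
import tensorflow as tf

# ProfilerCallback as defined above; module path taken from this repo's imports.
from official.resnet.keras.keras_common import ProfilerCallback

# Tiny stand-in model and data (assumed, illustration only).
model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation='softmax')])
model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy')
x = np.random.rand(640, 32).astype('float32')
y = np.random.randint(10, size=640)

# Profile batches 2 through 4; profiles are written to /tmp/profile and can
# be inspected with TensorBoard's profile plugin.
profiler_callback = ProfilerCallback('/tmp/profile', start_step=2, stop_step=4)
model.fit(x, y, batch_size=64, epochs=1, callbacks=[profiler_callback])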
@@ -143,19 +167,50 @@ def get_optimizer():
 def get_callbacks(learning_rate_schedule_fn, num_images):
   """Returns common callbacks."""
   time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
-  tensorboard_callback = tf.keras.callbacks.TensorBoard(
-      log_dir=FLAGS.model_dir)
   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule_fn,
       batch_size=FLAGS.batch_size,
       num_images=num_images)
+  callbacks = [time_callback, lr_callback]

-  return time_callback, tensorboard_callback, lr_callback
+  if FLAGS.enable_tensorboard:
+    tensorboard_callback = tf.keras.callbacks.TensorBoard(
+        log_dir=FLAGS.model_dir)
+    callbacks.append(tensorboard_callback)
+
+  if FLAGS.profile_steps:
+    profiler_callback = get_profiler_callback()
+    callbacks.append(profiler_callback)
+
+  return callbacks


-def build_stats(history, eval_output, time_callback):
+def get_profiler_callback():
+  """Validate profile_steps flag value and return profiler callback."""
+  profile_steps_error_message = (
+      'profile_steps must be a comma separated pair of positive integers, '
+      'specifying the first and last steps to be profiled.'
+  )
+  try:
+    profile_steps = [int(i) for i in FLAGS.profile_steps.split(',')]
+  except ValueError:
+    raise ValueError(profile_steps_error_message)
+  if len(profile_steps) != 2:
+    raise ValueError(profile_steps_error_message)
+  start_step, stop_step = profile_steps
+  if start_step < 0 or start_step > stop_step:
+    raise ValueError(profile_steps_error_message)
+  if FLAGS.enable_tensorboard:
+    tf.compat.v1.logging.warn(
+        'Both TensorBoard and profiler callbacks are used. Note that the '
+        'TensorBoard callback profiles the 2nd step (unless otherwise '
+        'specified). Please make sure the steps profiled by the two callbacks '
+        'do not overlap.')
+  return ProfilerCallback(FLAGS.model_dir, start_step, stop_step)
+
+
+def build_stats(history, eval_output, callbacks):
   """Normalizes and returns dictionary of stats.

   Args:
@@ -163,7 +218,8 @@ def build_stats(history, eval_output, time_callback):
       and sparse_categorical_accuracy.
     eval_output: Output of the eval step. Assumes first value is eval_loss and
       second value is accuracy_top_1.
-    time_callback: Time tracking callback likely used during keras.fit.
+    callbacks: a list of callbacks which might include a time history callback
+      used during keras.fit.

   Returns:
     Dictionary of normalized results.
@@ -183,16 +239,20 @@ def build_stats(history, eval_output, time_callback):
   elif 'sparse_categorical_accuracy' in train_hist:
     stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()

-  if time_callback:
-    timestamp_log = time_callback.timestamp_log
+  if not callbacks:
+    return stats
+
+  # Look for the time history callback which was used during keras.fit
+  for callback in callbacks:
+    if isinstance(callback, keras_utils.TimeHistory):
+      timestamp_log = callback.timestamp_log
       stats['step_timestamp_log'] = timestamp_log
-    stats['train_finish_time'] = time_callback.train_finish_time
+      stats['train_finish_time'] = callback.train_finish_time
       if len(timestamp_log) > 1:
         stats['avg_exp_per_second'] = (
-          time_callback.batch_size * time_callback.log_steps *
-          (len(time_callback.timestamp_log)-1) /
+            callback.batch_size * callback.log_steps *
+            (len(callback.timestamp_log)-1) /
             (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))

   return stats
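The avg_exp_per_second stat above reduces to the number of examples processed between the first and last logged timestamps, divided by the elapsed time. A small illustrative calculation (all numbers assumed, not from a real run):

# TimeHistory records a timestamp every `log_steps` batches of `batch_size`
# examples each; these values are made up for illustration.
batch_size, log_steps = 128, 100
timestamps = [10.0, 30.0, 50.0]  # assumed BatchTimestamp.timestamp values

examples_seen = batch_size * log_steps * (len(timestamps) - 1)         # 25600
avg_exp_per_second = examples_seen / (timestamps[-1] - timestamps[0])  # 640.0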
@@ -215,11 +275,14 @@ def define_keras_flags():
       help='The number of steps to run for training. If it is larger than '
       '# batches per epoch, then use # batches per epoch. When this flag is '
       'set, only one epoch is going to run for training.')
-  flags.DEFINE_boolean(
-      name='enable_e2e_xprof', default=False,
-      help='Save end-to-end profiling data to model dir using Xprof. Profiling '
-      'has an overhead on both computation and memory usage, and can generate '
-      'gigantic files when profiling a lot of steps.')
+  flags.DEFINE_string(
+      name='profile_steps', default=None,
+      help='Save profiling data to model dir at given range of steps. The '
+      'value must be a comma separated pair of positive integers, specifying '
+      'the first and last step to profile. For example, "--profile_steps=2,4" '
+      'triggers the profiler to process 3 steps, starting from the 2nd step. '
+      'Note that profiler has a non-trivial performance overhead, and the '
+      'output file can be gigantic if profiling many steps.')


 def get_synth_input_fn(height, width, num_channels, num_classes,
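To make the accepted flag values concrete, here is a standalone sketch of the same validation get_profiler_callback applies; the helper name parse_profile_steps is hypothetical, not part of this change:

def parse_profile_steps(value):
  # Hypothetical helper mirroring get_profiler_callback's checks.
  msg = ('profile_steps must be a comma separated pair of positive integers, '
         'specifying the first and last steps to be profiled.')
  try:
    steps = [int(i) for i in value.split(',')]
  except ValueError:
    raise ValueError(msg)
  if len(steps) != 2:
    raise ValueError(msg)
  start_step, stop_step = steps
  if start_step < 0 or start_step > stop_step:
    raise ValueError(msg)
  return start_step, stop_step

print(parse_profile_steps('2,4'))  # (2, 4): profile steps 2 through 4
# parse_profile_steps('4,2')       # raises ValueError (start after stop)
# parse_profile_steps('2')         # raises ValueError (not a pair)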
......
@@ -22,7 +22,6 @@ from absl import app as absl_app
 from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order
-from tensorflow.python.eager import profiler

 from official.resnet import imagenet_main
 from official.resnet.keras import keras_common
 from official.resnet.keras import resnet_model
@@ -177,7 +176,7 @@ def run(flags_obj):
                 optimizer=optimizer,
                 metrics=['sparse_categorical_accuracy'])

-  time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
+  callbacks = keras_common.get_callbacks(
       learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

   train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
@@ -199,12 +198,6 @@ def run(flags_obj):
     num_eval_steps = None
     validation_data = None

-  callbacks = [time_callback, lr_callback]
-  if flags_obj.enable_tensorboard:
-    callbacks.append(tensorboard_callback)
-
-  if flags_obj.enable_e2e_xprof:
-    profiler.start()
-
   history = model.fit(train_input_dataset,
                       epochs=train_epochs,
                       steps_per_epoch=train_steps,
@@ -214,16 +207,12 @@ def run(flags_obj):
                       validation_freq=flags_obj.epochs_between_evals,
                       verbose=2)

-  if flags_obj.enable_e2e_xprof:
-    results = profiler.stop()
-    profiler.save(flags_obj.model_dir, results)
-
   eval_output = None
   if not flags_obj.skip_eval:
     eval_output = model.evaluate(eval_input_dataset,
                                  steps=num_eval_steps,
                                  verbose=2)
-  stats = keras_common.build_stats(history, eval_output, time_callback)
+  stats = keras_common.build_stats(history, eval_output, callbacks)
   return stats
......