Unverified Commit 3f94db4e authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Add profiler callback for Keras models (#6528)

* Add profiler callback for Keras models

* Update build stats to identify time callback by type

* Add warning message when both TensorBoard and profiler callbacks are used
parent 7467ccde
...@@ -161,7 +161,7 @@ def run(flags_obj): ...@@ -161,7 +161,7 @@ def run(flags_obj):
optimizer=optimizer, optimizer=optimizer,
metrics=['categorical_accuracy']) metrics=['categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train']) learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
...@@ -180,10 +180,6 @@ def run(flags_obj): ...@@ -180,10 +180,6 @@ def run(flags_obj):
num_eval_steps = None num_eval_steps = None
validation_data = None validation_data = None
callbacks = [time_callback, lr_callback]
if flags_obj.enable_tensorboard:
callbacks.append(tensorboard_callback)
history = model.fit(train_input_dataset, history = model.fit(train_input_dataset,
epochs=train_epochs, epochs=train_epochs,
steps_per_epoch=train_steps, steps_per_epoch=train_steps,
...@@ -197,7 +193,7 @@ def run(flags_obj): ...@@ -197,7 +193,7 @@ def run(flags_obj):
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps, steps=num_eval_steps,
verbose=2) verbose=2)
stats = keras_common.build_stats(history, eval_output, time_callback) stats = keras_common.build_stats(history, eval_output, callbacks)
return stats return stats
......
...@@ -30,6 +30,7 @@ import tensorflow as tf ...@@ -30,6 +30,7 @@ import tensorflow as tf
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
# pylint: disable=ungrouped-imports # pylint: disable=ungrouped-imports
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import profiler
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2) gradient_descent_v2)
...@@ -78,6 +79,29 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback): ...@@ -78,6 +79,29 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
'change learning rate to %s.', self.epochs, batch, lr) 'change learning rate to %s.', self.epochs, batch, lr)
class ProfilerCallback(tf.keras.callbacks.Callback):
  """Save profiles in specified step range to log directory."""

  def __init__(self, log_dir, start_step, stop_step):
    """Stores the profiling window and the directory profiles are saved to."""
    super(ProfilerCallback, self).__init__()
    self.log_dir = log_dir
    self.start_step = start_step
    self.stop_step = stop_step

  def on_batch_begin(self, batch, logs=None):
    """Starts the profiler once the configured start step is reached."""
    if batch != self.start_step:
      return
    profiler.start()
    tf.compat.v1.logging.info('Profiler started at Step %s', self.start_step)

  def on_batch_end(self, batch, logs=None):
    """Stops the profiler and writes the collected profile at the stop step."""
    if batch != self.stop_step:
      return
    # NOTE(review): batch indices reset each epoch, so the window is relative
    # to the current epoch — presumably intended for single-epoch profiling.
    profile_result = profiler.stop()
    profiler.save(self.log_dir, profile_result)
    tf.compat.v1.logging.info(
        'Profiler saved profiles for steps between %s and %s to %s',
        self.start_step, self.stop_step, self.log_dir)
def get_config_proto_v1(): def get_config_proto_v1():
"""Return config proto according to flag settings, or None to use default.""" """Return config proto according to flag settings, or None to use default."""
config = None config = None
...@@ -143,19 +167,50 @@ def get_optimizer(): ...@@ -143,19 +167,50 @@ def get_optimizer():
def get_callbacks(learning_rate_schedule_fn, num_images):
  """Returns common callbacks."""
  # Time tracking and LR scheduling are always enabled; TensorBoard and the
  # profiler are opt-in via flags.
  callbacks = [
      keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps),
      LearningRateBatchScheduler(
          learning_rate_schedule_fn,
          batch_size=FLAGS.batch_size,
          num_images=num_images),
  ]

  if FLAGS.enable_tensorboard:
    callbacks.append(
        tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir))

  if FLAGS.profile_steps:
    callbacks.append(get_profiler_callback())

  return callbacks
def get_profiler_callback():
  """Validate profile_steps flag value and return profiler callback.

  Returns:
    A `ProfilerCallback` covering the step range given by the
    `--profile_steps` flag.

  Raises:
    ValueError: if `FLAGS.profile_steps` is not a comma separated pair of
      non-negative integers with start <= stop.
  """
  start_step, stop_step = parse_profile_steps(FLAGS.profile_steps)
  if FLAGS.enable_tensorboard:
    # The TensorBoard callback runs its own profiling pass; overlapping the
    # two produces confusing profiles, so warn loudly.
    tf.compat.v1.logging.warn(
        'Both TensorBoard and profiler callbacks are used. Note that the '
        'TensorBoard callback profiles the 2nd step (unless otherwise '
        'specified). Please make sure the steps profiled by the two callbacks '
        'do not overlap.')
  return ProfilerCallback(FLAGS.model_dir, start_step, stop_step)


def parse_profile_steps(profile_steps_str):
  """Parses and validates a 'start,stop' profile step specification.

  Module-internal helper for `get_profiler_callback`; kept FLAGS-free so the
  validation logic is independently testable.

  Args:
    profile_steps_str: comma separated pair of step numbers, e.g. '2,4'.

  Returns:
    A `(start_step, stop_step)` tuple of ints.

  Raises:
    ValueError: if the specification is not two integers, is negative, or has
      start > stop.
  """
  profile_steps_error_message = (
      'profile_steps must be a comma separated pair of positive integers, '
      'specifying the first and last steps to be profiled.'
  )
  try:
    profile_steps = [int(i) for i in profile_steps_str.split(',')]
  except ValueError:
    raise ValueError(profile_steps_error_message)
  if len(profile_steps) != 2:
    raise ValueError(profile_steps_error_message)
  start_step, stop_step = profile_steps
  if start_step < 0 or start_step > stop_step:
    raise ValueError(profile_steps_error_message)
  return start_step, stop_step
def build_stats(history, eval_output, callbacks):
"""Normalizes and returns dictionary of stats. """Normalizes and returns dictionary of stats.
Args: Args:
...@@ -163,7 +218,8 @@ def build_stats(history, eval_output, time_callback): ...@@ -163,7 +218,8 @@ def build_stats(history, eval_output, time_callback):
and sparse_categorical_accuracy. and sparse_categorical_accuracy.
eval_output: Output of the eval step. Assumes first value is eval_loss and eval_output: Output of the eval step. Assumes first value is eval_loss and
second value is accuracy_top_1. second value is accuracy_top_1.
time_callback: Time tracking callback likely used during keras.fit. callbacks: a list of callbacks which might include a time history callback
used during keras.fit.
Returns: Returns:
Dictionary of normalized results. Dictionary of normalized results.
...@@ -183,16 +239,20 @@ def build_stats(history, eval_output, time_callback): ...@@ -183,16 +239,20 @@ def build_stats(history, eval_output, time_callback):
elif 'sparse_categorical_accuracy' in train_hist: elif 'sparse_categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item() stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()
if time_callback: if not callbacks:
timestamp_log = time_callback.timestamp_log return stats
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = time_callback.train_finish_time # Look for the time history callback which was used during keras.fit
if len(timestamp_log) > 1: for callback in callbacks:
stats['avg_exp_per_second'] = ( if isinstance(callback, keras_utils.TimeHistory):
time_callback.batch_size * time_callback.log_steps * timestamp_log = callback.timestamp_log
(len(time_callback.timestamp_log)-1) / stats['step_timestamp_log'] = timestamp_log
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp)) stats['train_finish_time'] = callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
callback.batch_size * callback.log_steps *
(len(callback.timestamp_log)-1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
return stats return stats
...@@ -215,11 +275,14 @@ def define_keras_flags(): ...@@ -215,11 +275,14 @@ def define_keras_flags():
help='The number of steps to run for training. If it is larger than ' help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # batches per epoch. When this flag is ' '# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.') 'set, only one epoch is going to run for training.')
flags.DEFINE_boolean( flags.DEFINE_string(
name='enable_e2e_xprof', default=False, name='profile_steps', default=None,
help='Save end-to-end profiling data to model dir using Xprof. Profiling ' help='Save profiling data to model dir at given range of steps. The '
'has an overhead on both computation and memory usage, and can generate ' 'value must be a comma separated pair of positive integers, specifying '
'gigantic files when profiling a lot of steps.') 'the first and last step to profile. For example, "--profile_steps=2,4" '
'triggers the profiler to process 3 steps, starting from the 2nd step. '
'Note that profiler has a non-trivial performance overhead, and the '
'output file can be gigantic if profiling many steps.')
def get_synth_input_fn(height, width, num_channels, num_classes, def get_synth_input_fn(height, width, num_channels, num_classes,
......
...@@ -22,7 +22,6 @@ from absl import app as absl_app ...@@ -22,7 +22,6 @@ from absl import app as absl_app
from absl import flags from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.eager import profiler
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model from official.resnet.keras import resnet_model
...@@ -177,7 +176,7 @@ def run(flags_obj): ...@@ -177,7 +176,7 @@ def run(flags_obj):
optimizer=optimizer, optimizer=optimizer,
metrics=['sparse_categorical_accuracy']) metrics=['sparse_categorical_accuracy'])
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train']) learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
...@@ -199,12 +198,6 @@ def run(flags_obj): ...@@ -199,12 +198,6 @@ def run(flags_obj):
num_eval_steps = None num_eval_steps = None
validation_data = None validation_data = None
callbacks = [time_callback, lr_callback]
if flags_obj.enable_tensorboard:
callbacks.append(tensorboard_callback)
if flags_obj.enable_e2e_xprof:
profiler.start()
history = model.fit(train_input_dataset, history = model.fit(train_input_dataset,
epochs=train_epochs, epochs=train_epochs,
steps_per_epoch=train_steps, steps_per_epoch=train_steps,
...@@ -214,16 +207,12 @@ def run(flags_obj): ...@@ -214,16 +207,12 @@ def run(flags_obj):
validation_freq=flags_obj.epochs_between_evals, validation_freq=flags_obj.epochs_between_evals,
verbose=2) verbose=2)
if flags_obj.enable_e2e_xprof:
results = profiler.stop()
profiler.save(flags_obj.model_dir, results)
eval_output = None eval_output = None
if not flags_obj.skip_eval: if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps, steps=num_eval_steps,
verbose=2) verbose=2)
stats = keras_common.build_stats(history, eval_output, time_callback) stats = keras_common.build_stats(history, eval_output, callbacks)
return stats return stats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment