Unverified Commit 3254cabb authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Moved common keras code to utils. (#6859)

parent b9cab01b
......@@ -30,7 +30,6 @@ import tensorflow as tf
from official.utils.misc import keras_utils
# pylint: disable=ungrouped-imports
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import profiler
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
......@@ -146,29 +145,6 @@ class PiecewiseConstantDecayWithWarmup(
}
class ProfilerCallback(tf.keras.callbacks.Callback):
"""Save profiles in specified step range to log directory."""
def __init__(self, log_dir, start_step, stop_step):
super(ProfilerCallback, self).__init__()
self.log_dir = log_dir
self.start_step = start_step
self.stop_step = stop_step
def on_batch_begin(self, batch, logs=None):
if batch == self.start_step:
profiler.start()
tf.compat.v1.logging.info('Profiler started at Step %s', self.start_step)
def on_batch_end(self, batch, logs=None):
if batch == self.stop_step:
results = profiler.stop()
profiler.save(self.log_dir, results)
tf.compat.v1.logging.info(
'Profiler saved profiles for steps between %s and %s to %s',
self.start_step, self.stop_step, self.log_dir)
def get_config_proto_v1():
"""Return config proto according to flag settings, or None to use default."""
config = None
......@@ -250,37 +226,15 @@ def get_callbacks(learning_rate_schedule_fn, num_images):
callbacks.append(tensorboard_callback)
if FLAGS.profile_steps:
profiler_callback = get_profiler_callback()
profiler_callback = keras_utils.get_profiler_callback(
FLAGS.model_dir,
FLAGS.profile_steps,
FLAGS.enable_tensorboard)
callbacks.append(profiler_callback)
return callbacks
def get_profiler_callback():
"""Validate profile_steps flag value and return profiler callback."""
profile_steps_error_message = (
'profile_steps must be a comma separated pair of positive integers, '
'specifying the first and last steps to be profiled.'
)
try:
profile_steps = [int(i) for i in FLAGS.profile_steps.split(',')]
except ValueError:
raise ValueError(profile_steps_error_message)
if len(profile_steps) != 2:
raise ValueError(profile_steps_error_message)
start_step, stop_step = profile_steps
if start_step < 0 or start_step > stop_step:
raise ValueError(profile_steps_error_message)
if FLAGS.enable_tensorboard:
tf.compat.v1.logging.warn(
'Both TensorBoard and profiler callbacks are used. Note that the '
'TensorBoard callback profiles the 2nd step (unless otherwise '
'specified). Please make sure the steps profiled by the two callbacks '
'do not overlap.')
return ProfilerCallback(FLAGS.model_dir, start_step, stop_step)
def build_stats(history, eval_output, callbacks):
"""Normalizes and returns dictionary of stats.
......
......@@ -20,8 +20,8 @@ from __future__ import print_function
import time
from absl import flags
import tensorflow as tf
from tensorflow.python.eager import profiler
class BatchTimestamp(object):
......@@ -80,3 +80,51 @@ class TimeHistory(tf.keras.callbacks.Callback):
"BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
"'examples_per_second': %f}" %
(batch, elapsed_time, examples_per_second))
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard):
"""Validate profile_steps flag value and return profiler callback."""
profile_steps_error_message = (
'profile_steps must be a comma separated pair of positive integers, '
'specifying the first and last steps to be profiled.'
)
try:
profile_steps = [int(i) for i in profile_steps.split(',')]
except ValueError:
raise ValueError(profile_steps_error_message)
if len(profile_steps) != 2:
raise ValueError(profile_steps_error_message)
start_step, stop_step = profile_steps
if start_step < 0 or start_step > stop_step:
raise ValueError(profile_steps_error_message)
if enable_tensorboard:
tf.compat.v1.logging.warn(
'Both TensorBoard and profiler callbacks are used. Note that the '
'TensorBoard callback profiles the 2nd step (unless otherwise '
'specified). Please make sure the steps profiled by the two callbacks '
'do not overlap.')
return ProfilerCallback(model_dir, start_step, stop_step)
class ProfilerCallback(tf.keras.callbacks.Callback):
"""Save profiles in specified step range to log directory."""
def __init__(self, log_dir, start_step, stop_step):
super(ProfilerCallback, self).__init__()
self.log_dir = log_dir
self.start_step = start_step
self.stop_step = stop_step
def on_batch_begin(self, batch, logs=None):
if batch == self.start_step:
profiler.start()
tf.compat.v1.logging.info('Profiler started at Step %s', self.start_step)
def on_batch_end(self, batch, logs=None):
if batch == self.stop_step:
results = profiler.stop()
profiler.save(self.log_dir, results)
tf.compat.v1.logging.info(
'Profiler saved profiles for steps between %s and %s to %s',
self.start_step, self.stop_step, self.log_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment