".github/git@developer.sourcefind.cn:change/sglang.git" did not exist on "3289c1207d238fada989ca1dcd3b948befb870be"
Commit b60dc237 authored by Will Cromar's avatar Will Cromar Committed by A. Unique TensorFlower
Browse files

Write examples/second and steps/second summaries in TimeHistory callback.

PiperOrigin-RevId: 296507807
parent 706a0bd9
......@@ -298,13 +298,16 @@ class EpochHelper(object):
self._epoch_steps = epoch_steps
self._global_step = global_step
self._current_epoch = None
self._epoch_start_step = None
self._in_epoch = False
def epoch_begin(self):
"""Returns whether a new epoch should begin."""
if self._in_epoch:
return False
self._current_epoch = self._global_step.numpy() / self._epoch_steps
current_step = self._global_step.numpy()
self._epoch_start_step = current_step
self._current_epoch = current_step // self._epoch_steps
self._in_epoch = True
return True
......@@ -313,13 +316,18 @@ class EpochHelper(object):
if not self._in_epoch:
raise ValueError("`epoch_end` can only be called inside an epoch")
current_step = self._global_step.numpy()
epoch = current_step / self._epoch_steps
epoch = current_step // self._epoch_steps
if epoch > self._current_epoch:
self._in_epoch = False
return True
return False
@property
def batch_index(self):
"""Index of the next batch within the current epoch."""
return self._global_step.numpy() - self._epoch_start_step
@property
def current_epoch(self):
return self._current_epoch
......@@ -44,17 +44,28 @@ class BatchTimestamp(object):
class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models."""
def __init__(self, batch_size, log_steps):
def __init__(self, batch_size, log_steps, logdir=None):
"""Callback for logging performance.
Args:
batch_size: Total batch size.
log_steps: Interval of steps between logging of batch level stats.
logdir: Optional directory to write TensorBoard summaries.
"""
# TODO(wcromar): remove this parameter and rely on `logs` parameter of
# on_train_batch_end()
self.batch_size = batch_size
super(TimeHistory, self).__init__()
self.log_steps = log_steps
self.global_steps = 0
self.last_log_step = 0
self.steps_before_epoch = 0
self.steps_in_epoch = 0
self.start_time = None
if logdir:
self.summary_writer = tf.summary.create_file_writer(logdir)
else:
self.summary_writer = None
# Logs start of step 1 then end of each step based on log_steps interval.
self.timestamp_log = []
......@@ -62,38 +73,70 @@ class TimeHistory(tf.keras.callbacks.Callback):
# Records the time each epoch takes to run from start to finish of epoch.
self.epoch_runtime_log = []
@property
def global_steps(self):
"""The current 1-indexed global step."""
return self.steps_before_epoch + self.steps_in_epoch
@property
def average_steps_per_second(self):
"""The average training steps per second across all epochs."""
return self.global_steps / sum(self.epoch_runtime_log)
@property
def average_examples_per_second(self):
"""The average number of training examples per second across all epochs."""
return self.average_steps_per_second * self.batch_size
def on_train_end(self, logs=None):
self.train_finish_time = time.time()
if self.summary_writer:
self.summary_writer.flush()
def on_epoch_begin(self, epoch, logs=None):
self.epoch_start = time.time()
def on_batch_begin(self, batch, logs=None):
self.global_steps += 1
if self.global_steps == 1:
if not self.start_time:
self.start_time = time.time()
# Record the timestamp of the first global step
if not self.timestamp_log:
self.timestamp_log.append(BatchTimestamp(self.global_steps,
self.start_time))
def on_batch_end(self, batch, logs=None):
"""Records elapse time of the batch and calculates examples per second."""
if self.global_steps % self.log_steps == 0:
timestamp = time.time()
elapsed_time = timestamp - self.start_time
examples_per_second = (self.batch_size * self.log_steps) / elapsed_time
self.timestamp_log.append(BatchTimestamp(self.global_steps, timestamp))
self.steps_in_epoch = batch + 1
steps_since_last_log = self.global_steps - self.last_log_step
if steps_since_last_log >= self.log_steps:
now = time.time()
elapsed_time = now - self.start_time
steps_per_second = steps_since_last_log / elapsed_time
examples_per_second = steps_per_second * self.batch_size
self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
logging.info(
"BenchmarkMetric: {'global step':%d, 'time_taken': %f,"
"'examples_per_second': %f}",
self.global_steps, elapsed_time, examples_per_second)
self.start_time = timestamp
"TimeHistory: %.2f examples/second between steps %d and %d",
examples_per_second, self.last_log_step, self.global_steps)
if self.summary_writer:
with self.summary_writer.as_default():
tf.summary.scalar('global_step/sec', steps_per_second,
self.global_steps)
tf.summary.scalar('examples/sec', examples_per_second,
self.global_steps)
self.last_log_step = self.global_steps
self.start_time = None
def on_epoch_end(self, epoch, logs=None):
epoch_run_time = time.time() - self.epoch_start
self.epoch_runtime_log.append(epoch_run_time)
logging.info(
"BenchmarkMetric: {'epoch':%d, 'time_taken': %f}",
epoch, epoch_run_time)
self.steps_before_epoch += self.steps_in_epoch
self.steps_in_epoch = 0
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
......
......@@ -188,7 +188,10 @@ def get_callbacks(
enable_checkpoint_and_export=False,
model_dir=None):
"""Returns common callbacks."""
time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
time_callback = keras_utils.TimeHistory(
FLAGS.batch_size,
FLAGS.log_steps,
logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
callbacks = [time_callback]
if not FLAGS.use_tensor_lr and learning_rate_schedule_fn:
......@@ -265,11 +268,9 @@ def build_stats(history, eval_output, callbacks):
timestamp_log = callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
callback.batch_size * callback.log_steps *
(len(callback.timestamp_log)-1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
if callback.epoch_runtime_log:
stats['avg_exp_per_second'] = callback.average_examples_per_second
return stats
......
......@@ -64,15 +64,8 @@ def build_stats(runnable, time_callback):
timestamp_log = time_callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = time_callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
time_callback.batch_size * time_callback.log_steps *
(len(time_callback.timestamp_log) - 1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
avg_exp_per_second = tf.reduce_mean(
runnable.examples_per_second_history).numpy(),
stats['avg_exp_per_second'] = avg_exp_per_second
if time_callback.epoch_runtime_log:
stats['avg_exp_per_second'] = time_callback.average_examples_per_second
return stats
......@@ -154,8 +147,10 @@ def run(flags_obj):
'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
train_epochs * per_epoch_steps, eval_steps)
time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
flags_obj.log_steps)
time_callback = keras_utils.TimeHistory(
flags_obj.batch_size,
flags_obj.log_steps,
logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
with distribution_utils.get_strategy_scope(strategy):
runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
per_epoch_steps)
......
......@@ -114,7 +114,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
# Handling epochs.
self.epoch_steps = epoch_steps
self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)
self.examples_per_second_history = []
def build_train_dataset(self):
"""See base class."""
......@@ -147,8 +146,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
self.train_loss.reset_states()
self.train_accuracy.reset_states()
self.time_callback.on_batch_begin(self.global_step)
self._epoch_begin()
self.time_callback.on_batch_begin(self.epoch_helper.batch_index)
def train_step(self, iterator):
"""See base class."""
......@@ -194,12 +193,13 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
def train_loop_end(self):
"""See base class."""
self.time_callback.on_batch_end(self.global_step)
self._epoch_end()
return {
metrics = {
'train_loss': self.train_loss.result(),
'train_accuracy': self.train_accuracy.result(),
}
self.time_callback.on_batch_end(self.epoch_helper.batch_index - 1)
self._epoch_end()
return metrics
def eval_begin(self):
"""See base class."""
......@@ -234,10 +234,3 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
def _epoch_end(self):
if self.epoch_helper.epoch_end():
self.time_callback.on_epoch_end(self.epoch_helper.current_epoch)
epoch_time = self.time_callback.epoch_runtime_log[-1]
steps_per_second = self.epoch_steps / epoch_time
examples_per_second = steps_per_second * self.flags_obj.batch_size
self.examples_per_second_history.append(examples_per_second)
tf.summary.scalar('global_step/sec', steps_per_second)
tf.summary.scalar('examples/sec', examples_per_second)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment