Commit 31da2245 authored by Jeremiah Harmsen, committed by A. Unique TensorFlower

Add calls to the custom callbacks' on_epoch_begin() and on_epoch_end() methods.

Minor refactoring to (1) encapsulate the callbacks in a CallbackList and (2) thread evaluation results through to the on_epoch_end callback.
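For context on point (1): tf.keras.callbacks.CallbackList fans each on_*() hook out to every wrapped callback, so the training loop can hold a single handle instead of looping over custom_callbacks at every hook. A minimal sketch of the behavior this change relies on, assuming current tf.keras semantics; the LoggingCallback class is illustrative, not part of the commit:

import tensorflow as tf

class LoggingCallback(tf.keras.callbacks.Callback):
  """Illustrative callback: prints whatever the loop threads through `logs`."""

  def on_epoch_begin(self, epoch, logs=None):
    print('epoch %d starting' % epoch)

  def on_epoch_end(self, epoch, logs=None):
    # For point (2): `logs` carries the dict returned by the evaluation run,
    # or None when no eval_input_fn was provided.
    print('epoch %d done, eval logs: %s' % (epoch, logs))

callback_list = tf.keras.callbacks.CallbackList([LoggingCallback()])
callback_list.on_epoch_begin(1)
callback_list.on_epoch_end(1, {'accuracy': 0.9})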

PiperOrigin-RevId: 306827406
parent 13dd0f7f
@@ -153,7 +153,8 @@ def run_customized_training_loop(
         `model_fn`.
       custom_callbacks: A list of Keras Callbacks objects to run during
         training. More specifically, `on_batch_begin()`, `on_batch_end()`,
-        methods are invoked during training.
+        `on_epoch_begin()`, `on_epoch_end()` methods are invoked during
+        training. Note that some metrics may be missing from `logs`.
       run_eagerly: Whether to run model training in pure eager execution. This
         should be disabled for TPUStrategy.
       sub_model_export_name: If not None, will export `sub_model` returned by
@@ -225,6 +226,8 @@ def run_customized_training_loop(
     raise ValueError(
         'if `metric_fn` is specified, metric_fn must be a callable.')
 
+  callback_list = tf.keras.callbacks.CallbackList(custom_callbacks)
+
   total_training_steps = steps_per_epoch * epochs
   train_iterator = _get_input_iterator(train_input_fn, strategy)
@@ -361,32 +364,30 @@ def run_customized_training_loop(
     test_step = tf.function(test_step)
 
   def _run_evaluation(current_training_step, test_iterator):
-    """Runs validation steps and aggregate metrics."""
+    """Runs validation steps and aggregates metrics.
+
+    Args:
+      current_training_step: tf.int32 tensor containing the current step.
+      test_iterator: distributed iterator of test datasets.
+
+    Returns:
+      A dict of metric names and values.
+    """
     for _ in range(eval_steps):
       test_step(test_iterator)
 
+    logs = {}
     with eval_summary_writer.as_default():
       for metric in eval_metrics + model.metrics:
         metric_value = _float_metric_value(metric)
+        logs[metric.name] = metric_value
         logging.info('Step: [%d] Validation %s = %f', current_training_step,
                      metric.name, metric_value)
         tf.summary.scalar(
             metric.name, metric_value, step=current_training_step)
       eval_summary_writer.flush()
-
-  def _run_callbacks_on_batch_begin(batch):
-    """Runs custom callbacks at the start of every step."""
-    if not custom_callbacks:
-      return
-    for callback in custom_callbacks:
-      callback.on_batch_begin(batch)
-
-  def _run_callbacks_on_batch_end(batch, logs):
-    """Runs custom callbacks at the end of every step."""
-    if not custom_callbacks:
-      return
-    for callback in custom_callbacks:
-      callback.on_batch_end(batch, logs)
+    return logs
 
   # Training loop starts here.
   checkpoint = tf.train.Checkpoint(
@@ -407,13 +408,16 @@ def run_customized_training_loop(
   checkpoint_name = 'ctl_step_{step}.ckpt'
 
   while current_step < total_training_steps:
+    if current_step % steps_per_epoch == 0:
+      callback_list.on_epoch_begin(int(current_step / steps_per_epoch) + 1)
+
     # Training loss/metrics take the average over steps inside the micro
     # training loop. We reset their values before each round.
     train_loss_metric.reset_states()
    for metric in train_metrics + model.metrics:
       metric.reset_states()
 
-    _run_callbacks_on_batch_begin(current_step)
+    callback_list.on_batch_begin(current_step)
     # Runs several steps in the host while loop.
     steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
@@ -428,7 +432,7 @@ def run_customized_training_loop(
         tf.convert_to_tensor(steps, dtype=tf.int32))
     train_loss = _float_metric_value(train_loss_metric)
     current_step += steps
-    _run_callbacks_on_batch_end(current_step - 1, {'loss': train_loss})
+    callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
 
     # Updates training logging.
     training_status = 'Train Step: %d/%d  / loss = %s' % (
@@ -458,25 +462,30 @@ def run_customized_training_loop(
       if current_step < total_training_steps:
         _save_checkpoint(strategy, checkpoint, model_dir,
                          checkpoint_name.format(step=current_step))
+        logs = None
         if eval_input_fn:
           logging.info('Running evaluation after step: %s.', current_step)
-          _run_evaluation(current_step,
-                          _get_input_iterator(eval_input_fn, strategy))
+          logs = _run_evaluation(current_step,
+                                 _get_input_iterator(eval_input_fn, strategy))
           # Re-initialize evaluation metric.
           for metric in eval_metrics + model.metrics:
             metric.reset_states()
+        callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
 
   if sub_model_export_name:
     _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
                      '%s.ckpt' % sub_model_export_name)
 
   _save_checkpoint(strategy, checkpoint, model_dir,
                    checkpoint_name.format(step=current_step))
+  logs = None
   if eval_input_fn:
     logging.info('Running final evaluation after training is complete.')
-    _run_evaluation(current_step,
-                    _get_input_iterator(eval_input_fn, strategy))
+    logs = _run_evaluation(current_step,
+                           _get_input_iterator(eval_input_fn, strategy))
+  callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
 
   training_summary = {
       'total_training_steps': total_training_steps,
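A worked example (not part of the commit) of the indices these hooks receive: with steps_per_epoch=20, steps_per_loop=10 and epochs=2, the host loop advances current_step by 10 per iteration, so on_batch_begin fires at steps 0, 10, 20, 30, on_batch_end at 9, 19, 29, 39, on_epoch_begin gets the 1-based epoch int(current_step / steps_per_epoch) + 1, and on_epoch_end gets int(current_step / steps_per_epoch). A plain-Python sketch that approximates the ordering:

# Approximates the callback ordering of the loop above; checkpointing and
# evaluation at the epoch boundary are elided.
steps_per_epoch, steps_per_loop, total_steps = 20, 10, 40

current_step = 0
while current_step < total_steps:
  if current_step % steps_per_epoch == 0:
    print('on_epoch_begin', current_step // steps_per_epoch + 1)
  print('on_batch_begin', current_step)
  current_step += steps_per_loop  # one host loop of steps_per_loop steps
  print('on_batch_end', current_step - 1)
  if current_step % steps_per_epoch == 0:
    print('on_epoch_end', current_step // steps_per_epoch)

The hunks below are from the accompanying unit test, which asserts exactly this sequence.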
@@ -135,6 +135,27 @@ def check_eventfile_for_keyword(keyword, summary_dir):
   return any(summaries_with_matching_keyword(keyword, summary_dir))
 
 
+class RecordingCallback(tf.keras.callbacks.Callback):
+
+  def __init__(self):
+    self.batch_begin = []  # (batch, logs)
+    self.batch_end = []  # (batch, logs)
+    self.epoch_begin = []  # (epoch, logs)
+    self.epoch_end = []  # (epoch, logs)
+
+  def on_batch_begin(self, batch, logs=None):
+    self.batch_begin.append((batch, logs))
+
+  def on_batch_end(self, batch, logs=None):
+    self.batch_end.append((batch, logs))
+
+  def on_epoch_begin(self, epoch, logs=None):
+    self.epoch_begin.append((epoch, logs))
+
+  def on_epoch_end(self, epoch, logs=None):
+    self.epoch_end.append((epoch, logs))
+
+
 class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
@@ -224,6 +245,41 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         check_eventfile_for_keyword('mean_input',
                                     os.path.join(model_dir, 'summaries/eval')))
 
+  @combinations.generate(eager_strategy_combinations())
+  def test_train_check_callbacks(self, distribution):
+    model_dir = self.get_temp_dir()
+    callback = RecordingCallback()
+    callbacks = [callback]
+    input_fn = create_fake_data_input_fn(
+        batch_size=8, features_shape=[128], num_classes=3)
+    model_training_utils.run_customized_training_loop(
+        strategy=distribution,
+        model_fn=self._model_fn,
+        loss_fn=tf.keras.losses.categorical_crossentropy,
+        model_dir=model_dir,
+        steps_per_epoch=20,
+        steps_per_loop=10,
+        epochs=2,
+        train_input_fn=input_fn,
+        eval_input_fn=input_fn,
+        eval_steps=10,
+        init_checkpoint=None,
+        metric_fn=metric_fn,
+        custom_callbacks=callbacks,
+        run_eagerly=False)
+
+    self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
+    epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
+    self.assertEqual(list(epoch_ends), [1, 2])
+    for info in epoch_end_infos:
+      self.assertIn('accuracy', info)
+
+    self.assertEqual(callback.batch_begin,
+                     [(0, {}), (10, {}), (20, {}), (30, {})])
+    batch_ends, batch_end_infos = zip(*callback.batch_end)
+    self.assertEqual(list(batch_ends), [9, 19, 29, 39])
+    for info in batch_end_infos:
+      self.assertIn('loss', info)
+
   @combinations.generate(
       combinations.combine(
           distribution=[
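Note the docstring caveat that some metrics may be missing from logs: on_batch_end receives only {'loss': ...}, and on_epoch_end receives None when no eval_input_fn is given. A hedged sketch of a defensive consumer (BestAccuracyCallback is illustrative, not part of the commit):

import tensorflow as tf

class BestAccuracyCallback(tf.keras.callbacks.Callback):
  """Illustrative callback: tracks the best eval accuracy seen so far."""

  def __init__(self):
    super(BestAccuracyCallback, self).__init__()
    self.best_accuracy = None

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}  # `logs` is None when no eval_input_fn was provided
    accuracy = logs.get('accuracy')  # may be absent; see the docstring note
    if accuracy is not None and (self.best_accuracy is None or
                                 accuracy > self.best_accuracy):
      self.best_accuracy = accuracy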