Commit febaae9a authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Some improvements and bug fixes to Controller:

1. Fix a bug that checkpoint will be saved after every training loop.
2. Only create the training and eval summaries writers if the corresponding `train_fn` and `eval_fn` are passed.
3. Flush the summary writers after training and eval finish.
4. Add a Controller test.

Also make sure there is no evaluation happening in Resnet CTL example if `skip_eval=True`.

PiperOrigin-RevId: 301489305
parent 101f1f04
...@@ -78,9 +78,10 @@ class Controller(object): ...@@ -78,9 +78,10 @@ class Controller(object):
eval_summary_dir: The directory to write eval summaries. If None, it will eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`. be set to `summary_dir`.
eval_steps: Number of steps to run evaluation. eval_steps: Number of steps to run evaluation.
eval_interval: Step interval for evaluation. If None, will skip eval_interval: Step interval for evaluation. If None, will skip evaluation
evaluation. Note that evaluation only happens outside the training loop, in the middle of training. Note that evaluation only happens outside the
which the loop iteration is specify by `steps_per_loop` parameter. training loop, which the loop iteration is specify by `steps_per_loop`
parameter.
Raises: Raises:
ValueError: If both `train_fn` and `eval_fn` are None. ValueError: If both `train_fn` and `eval_fn` are None.
...@@ -111,35 +112,41 @@ class Controller(object): ...@@ -111,35 +112,41 @@ class Controller(object):
self.train_fn = train_fn self.train_fn = train_fn
self.eval_fn = eval_fn self.eval_fn = eval_fn
self.global_step = global_step self.global_step = global_step
self.train_steps = train_steps
self.steps_per_loop = steps_per_loop
self.summary_dir = summary_dir or checkpoint_manager.directory
self.checkpoint_manager = checkpoint_manager self.checkpoint_manager = checkpoint_manager
self.summary_interval = summary_interval if self.train_fn is not None:
summary_writer = tf.summary.create_file_writer( self.train_steps = train_steps
self.summary_dir) if self.summary_interval else None self.steps_per_loop = steps_per_loop
# TODO(rxsang): Consider pass SummaryManager directly into Controller for self.summary_dir = summary_dir or checkpoint_manager.directory
# maximum customizability.
self.summary_manager = utils.SummaryManager( self.summary_interval = summary_interval
summary_writer, summary_writer = tf.summary.create_file_writer(
tf.summary.scalar, self.summary_dir) if self.summary_interval else None
global_step=self.global_step, # TODO(rxsang): Consider pass SummaryManager directly into Controller for
summary_interval=self.summary_interval) # maximum customizability.
self.summary_manager = utils.SummaryManager(
summary_writer,
tf.summary.scalar,
global_step=self.global_step,
summary_interval=self.summary_interval)
if self.eval_fn is not None:
eval_summary_dir = eval_summary_dir or self.summary_dir
eval_summary_writer = tf.summary.create_file_writer(
eval_summary_dir) if eval_summary_dir else None
self.eval_summary_manager = utils.SummaryManager(
eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
self.eval_steps = eval_steps
self.eval_interval = eval_interval
# Create and initialize the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.global_step.numpy())
if self.global_step: if self.global_step:
tf.summary.experimental.set_step(self.global_step) tf.summary.experimental.set_step(self.global_step)
self.eval_summary_dir = eval_summary_dir or self.summary_dir
eval_summary_writer = tf.summary.create_file_writer(self.eval_summary_dir)
self.eval_summary_manager = utils.SummaryManager(
eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
self.eval_steps = eval_steps
self.eval_interval = eval_interval
# Restore Model if needed. # Restore Model if needed.
if self.checkpoint_manager is not None: if self.checkpoint_manager is not None:
model_restored = self._restore_model() model_restored = self._restore_model()
...@@ -150,10 +157,6 @@ class Controller(object): ...@@ -150,10 +157,6 @@ class Controller(object):
checkpoint_number=self.global_step) checkpoint_number=self.global_step)
logging.info("Saved checkpoins in %s", ckpt_path) logging.info("Saved checkpoins in %s", ckpt_path)
# Create and initialize the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.global_step.numpy())
def _restore_model(self, checkpoint_path=None): def _restore_model(self, checkpoint_path=None):
"""Restore or initialize the model. """Restore or initialize the model.
...@@ -186,11 +189,12 @@ class Controller(object): ...@@ -186,11 +189,12 @@ class Controller(object):
self._log_info(info) self._log_info(info)
self.eval_summary_manager.write_summaries(eval_outputs) self.eval_summary_manager.write_summaries(eval_outputs)
self.eval_summary_manager.flush()
def _maybe_save_checkpoints(self, current_step, force_trigger=False): def _maybe_save_checkpoints(self, current_step, force_trigger=False):
if self.checkpoint_manager.checkpoint_interval: if self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save( ckpt_path = self.checkpoint_manager.save(
checkpoint_number=current_step, check_interval=force_trigger) checkpoint_number=current_step, check_interval=not force_trigger)
if ckpt_path is not None: if ckpt_path is not None:
logging.info("Saved checkpoins in %s", ckpt_path) logging.info("Saved checkpoins in %s", ckpt_path)
...@@ -265,6 +269,7 @@ class Controller(object): ...@@ -265,6 +269,7 @@ class Controller(object):
self._maybe_evaluate(current_step) self._maybe_evaluate(current_step)
self.summary_manager.write_summaries(train_outputs, always_write=True) self.summary_manager.write_summaries(train_outputs, always_write=True)
self.summary_manager.flush()
self._maybe_save_checkpoints(current_step, force_trigger=True) self._maybe_save_checkpoints(current_step, force_trigger=True)
if evaluate: if evaluate:
self._maybe_evaluate(current_step, force_trigger=True) self._maybe_evaluate(current_step, force_trigger=True)
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.staging.training.controller."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.staging.training import controller
from official.staging.training import standard_runnable
def all_strategy_combinations():
"""Gets combinations of distribution strategies."""
return combinations.combine(
strategy=[
strategy_combinations.one_device_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode="eager",
)
def create_model():
x = tf.keras.layers.Input(shape=(3,), name="input")
y = tf.keras.layers.Dense(4, name="dense")(x)
model = tf.keras.Model(x, y)
return model
def summaries_with_matching_keyword(keyword, summary_dir):
"""Yields summary protos matching given keyword from event file."""
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
tf.compat.v1.logging.error(event)
yield event.summary
def check_eventfile_for_keyword(keyword, summary_dir):
"""Checks event files for the keyword."""
return any(summaries_with_matching_keyword(keyword, summary_dir))
def dataset_fn(ctx):
del ctx
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10, drop_remainder=True)
return dataset
class TestRunnable(standard_runnable.StandardTrainable,
standard_runnable.StandardEvaluable):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self):
standard_runnable.StandardTrainable.__init__(self)
standard_runnable.StandardEvaluable.__init__(self)
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop()
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
def build_train_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
def train_step(self, iterator):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets)
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def train_loop_end(self):
return {
"loss": self.train_loss.result(),
}
def build_eval_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
def eval_begin(self):
self.eval_loss.reset_states()
def eval_step(self, iterator):
def _replicated_step(inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets)
self.eval_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def eval_end(self):
return {
"eval_loss": self.eval_loss.result(),
}
class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ControllerTest, self).setUp()
self.model_dir = self.get_temp_dir()
@combinations.generate(all_strategy_combinations())
def test_train_and_evaluate(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
strategy=strategy,
train_fn=test_runnable.train,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.train(evaluate=True)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
@combinations.generate(all_strategy_combinations())
def test_train_only(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
strategy=strategy,
train_fn=test_runnable.train,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
)
test_controller.train(evaluate=False)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
@combinations.generate(all_strategy_combinations())
def test_evaluate_only(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint(model=test_runnable.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step)
test_controller = controller.Controller(
strategy=strategy,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.evaluate()
# Only eval summaries are written
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
if __name__ == "__main__":
tf.test.main()
...@@ -193,6 +193,11 @@ class SummaryManager(object): ...@@ -193,6 +193,11 @@ class SummaryManager(object):
"""Returns the underlying summary writer.""" """Returns the underlying summary writer."""
return self._summary_writer return self._summary_writer
def flush(self):
"""Flush the underlying summary writer."""
if self._enabled:
tf.summary.flush(self._summary_writer)
def write_summaries(self, items, always_write=True): def write_summaries(self, items, always_write=True):
"""Write a bulk of summaries. """Write a bulk of summaries.
......
...@@ -147,9 +147,7 @@ def run(flags_obj): ...@@ -147,9 +147,7 @@ def run(flags_obj):
runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback, runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
per_epoch_steps) per_epoch_steps)
eval_interval = ( eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
flags_obj.epochs_between_evals *
per_epoch_steps if not flags_obj.skip_eval else None)
checkpoint_interval = ( checkpoint_interval = (
per_epoch_steps if flags_obj.enable_checkpoint_and_export else None) per_epoch_steps if flags_obj.enable_checkpoint_and_export else None)
summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None summary_interval = per_epoch_steps if flags_obj.enable_tensorboard else None
...@@ -174,7 +172,7 @@ def run(flags_obj): ...@@ -174,7 +172,7 @@ def run(flags_obj):
eval_interval=eval_interval) eval_interval=eval_interval)
time_callback.on_train_begin() time_callback.on_train_begin()
resnet_controller.train(evaluate=True) resnet_controller.train(evaluate=not flags_obj.skip_eval)
time_callback.on_train_end() time_callback.on_train_end()
stats = build_stats(runnable, time_callback) stats = build_stats(runnable, time_callback)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment