Commit 13d44a05 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix controller bugs. Add tests for optional args.

PiperOrigin-RevId: 302323163
parent a348a90b
......@@ -117,11 +117,18 @@ class Controller(object):
if self.train_fn is not None:
self.train_steps = train_steps
self.steps_per_loop = steps_per_loop
self.summary_dir = summary_dir or checkpoint_manager.directory
if summary_dir:
self.summary_dir = summary_dir
elif checkpoint_manager:
self.summary_dir = checkpoint_manager.directory
else:
self.summary_dir = None
self.summary_interval = summary_interval
summary_writer = tf.summary.create_file_writer(
self.summary_dir) if self.summary_interval else None
if self.summary_dir and self.summary_interval:
summary_writer = tf.summary.create_file_writer(self.summary_dir)
else:
summary_writer = None
# TODO(rxsang): Consider pass SummaryManager directly into Controller for
# maximum customizability.
self.summary_manager = utils.SummaryManager(
......@@ -140,14 +147,14 @@ class Controller(object):
self.eval_steps = eval_steps
self.eval_interval = eval_interval
# Create and initialize the interval triggers.
# Creates and initializes the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.global_step.numpy())
self.global_step.numpy()) # pytype: disable=attribute-error
if self.global_step:
tf.summary.experimental.set_step(self.global_step)
# Restore Model if needed.
# Restores the model if needed.
if self.checkpoint_manager is not None:
model_restored = self._restore_model()
if not model_restored and self.checkpoint_manager.checkpoint_interval:
......@@ -192,7 +199,7 @@ class Controller(object):
self.eval_summary_manager.flush()
def _maybe_save_checkpoints(self, current_step, force_trigger=False):
if self.checkpoint_manager.checkpoint_interval:
if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=current_step, check_interval=not force_trigger)
if ckpt_path is not None:
......
......@@ -143,6 +143,52 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
super(ControllerTest, self).setUp()
self.model_dir = self.get_temp_dir()
def test_no_checkpoint(self):
test_runnable = TestRunnable()
# No checkpoint manager and no strategy.
test_controller = controller.Controller(
train_fn=test_runnable.train,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.train(evaluate=True)
self.assertEqual(test_runnable.global_step.numpy(), 10)
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# No checkpoint, so global step starts from 0.
test_runnable.global_step.assign(0)
test_controller.train(evaluate=True)
self.assertEqual(test_runnable.global_step.numpy(), 10)
def test_no_checkpoint_and_summaries(self):
test_runnable = TestRunnable()
# No checkpoint + summary directories.
test_controller = controller.Controller(
train_fn=test_runnable.train,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
eval_steps=2,
eval_interval=5)
test_controller.train(evaluate=True)
self.assertEqual(test_runnable.global_step.numpy(), 10)
@combinations.generate(all_strategy_combinations())
def test_train_and_evaluate(self, strategy):
with strategy.scope():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment