"docs/EN/source/vscode:/vscode.git/clone" did not exist on "dfc3b85ed20b048486db697a703cc542835802a1"
Commit 13d44a05 authored by Hongkun Yu, committed by A. Unique TensorFlower

Fix controller bugs. Add tests for optional args.

PiperOrigin-RevId: 302323163
parent a348a90b
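The diff below fixes two places where Controller dereferenced optional constructor arguments without checking them: the summary-directory fallback read checkpoint_manager.directory, and _maybe_save_checkpoints read checkpoint_manager.checkpoint_interval, even when no checkpoint manager was passed. A minimal standalone sketch of the new summary-directory fallback (the resolve_summary_dir helper is illustrative only, not code from the repository):

# Illustrative helper, not part of controller.py: mirrors the fallback logic
# this commit introduces for picking a training summary directory.
def resolve_summary_dir(summary_dir, checkpoint_manager):
  if summary_dir:
    return summary_dir
  if checkpoint_manager:
    # Fall back to the checkpoint directory only when a manager exists.
    return checkpoint_manager.directory
  # The old expression `summary_dir or checkpoint_manager.directory` raised
  # AttributeError here when checkpoint_manager was None.
  return None

assert resolve_summary_dir(None, None) is None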
@@ -117,11 +117,18 @@ class Controller(object):
     if self.train_fn is not None:
       self.train_steps = train_steps
       self.steps_per_loop = steps_per_loop
-      self.summary_dir = summary_dir or checkpoint_manager.directory
+      if summary_dir:
+        self.summary_dir = summary_dir
+      elif checkpoint_manager:
+        self.summary_dir = checkpoint_manager.directory
+      else:
+        self.summary_dir = None
 
       self.summary_interval = summary_interval
-      summary_writer = tf.summary.create_file_writer(
-          self.summary_dir) if self.summary_interval else None
+      if self.summary_dir and self.summary_interval:
+        summary_writer = tf.summary.create_file_writer(self.summary_dir)
+      else:
+        summary_writer = None
       # TODO(rxsang): Consider pass SummaryManager directly into Controller for
       # maximum customizability.
       self.summary_manager = utils.SummaryManager(
@@ -140,14 +147,14 @@ class Controller(object):
       self.eval_steps = eval_steps
       self.eval_interval = eval_interval
 
-      # Create and initialize the interval triggers.
+      # Creates and initializes the interval triggers.
       self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
-                                                self.global_step.numpy())
+                                                self.global_step.numpy())  # pytype: disable=attribute-error
 
     if self.global_step:
       tf.summary.experimental.set_step(self.global_step)
 
-    # Restore Model if needed.
+    # Restores the model if needed.
     if self.checkpoint_manager is not None:
       model_restored = self._restore_model()
       if not model_restored and self.checkpoint_manager.checkpoint_interval:
@@ -192,7 +199,7 @@ class Controller(object):
       self.eval_summary_manager.flush()
 
   def _maybe_save_checkpoints(self, current_step, force_trigger=False):
-    if self.checkpoint_manager.checkpoint_interval:
+    if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
       ckpt_path = self.checkpoint_manager.save(
           checkpoint_number=current_step, check_interval=not force_trigger)
       if ckpt_path is not None:
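The _maybe_save_checkpoints guard relies on Python's short-circuiting `and`: when no CheckpointManager was supplied, the attribute on the right is never read and the call becomes a safe no-op. A standalone sketch of the pattern, using a hypothetical maybe_save helper and a stub manager rather than the real classes:

# Illustrative stand-ins only; not the real Controller or CheckpointManager.
class StubCheckpointManager(object):
  checkpoint_interval = 100

  def save(self, checkpoint_number, check_interval):
    return "ckpt-%d" % checkpoint_number

def maybe_save(checkpoint_manager, current_step, force_trigger=False):
  # With checkpoint_manager=None the right operand is never evaluated,
  # so no AttributeError is raised and nothing is saved.
  if checkpoint_manager and checkpoint_manager.checkpoint_interval:
    return checkpoint_manager.save(
        checkpoint_number=current_step, check_interval=not force_trigger)
  return None

assert maybe_save(None, current_step=10) is None
assert maybe_save(StubCheckpointManager(), current_step=10) == "ckpt-10"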
@@ -143,6 +143,52 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
     super(ControllerTest, self).setUp()
     self.model_dir = self.get_temp_dir()
 
+  def test_no_checkpoint(self):
+    test_runnable = TestRunnable()
+    # No checkpoint manager and no strategy.
+    test_controller = controller.Controller(
+        train_fn=test_runnable.train,
+        eval_fn=test_runnable.evaluate,
+        global_step=test_runnable.global_step,
+        train_steps=10,
+        steps_per_loop=2,
+        summary_dir=os.path.join(self.model_dir, "summaries/train"),
+        summary_interval=2,
+        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
+        eval_steps=2,
+        eval_interval=5)
+    test_controller.train(evaluate=True)
+    self.assertEqual(test_runnable.global_step.numpy(), 10)
+    # Loss and accuracy values should be written into summaries.
+    self.assertNotEmpty(
+        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
+    self.assertTrue(
+        check_eventfile_for_keyword(
+            "loss", os.path.join(self.model_dir, "summaries/train")))
+    self.assertNotEmpty(
+        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
+    self.assertTrue(
+        check_eventfile_for_keyword(
+            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
+    # No checkpoint, so global step starts from 0.
+    test_runnable.global_step.assign(0)
+    test_controller.train(evaluate=True)
+    self.assertEqual(test_runnable.global_step.numpy(), 10)
+
+  def test_no_checkpoint_and_summaries(self):
+    test_runnable = TestRunnable()
+    # No checkpoint + summary directories.
+    test_controller = controller.Controller(
+        train_fn=test_runnable.train,
+        eval_fn=test_runnable.evaluate,
+        global_step=test_runnable.global_step,
+        train_steps=10,
+        steps_per_loop=2,
+        eval_steps=2,
+        eval_interval=5)
+    test_controller.train(evaluate=True)
+    self.assertEqual(test_runnable.global_step.numpy(), 10)
+
   @combinations.generate(all_strategy_combinations())
   def test_train_and_evaluate(self, strategy):
     with strategy.scope():