Commit 7c29567d authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Remove the logic of saving checkpoint at step 0 in Orbit.

PiperOrigin-RevId: 326116282
parent 67d39836
...@@ -134,13 +134,9 @@ class Controller: ...@@ -134,13 +134,9 @@ class Controller:
# TODO(momernick): We probably only want to do this on certain occasions? # TODO(momernick): We probably only want to do this on certain occasions?
if self.checkpoint_manager is not None: if self.checkpoint_manager is not None:
checkpoint_interval = self.checkpoint_manager.checkpoint_interval checkpoint_interval = self.checkpoint_manager.checkpoint_interval
model_restored = self.restore_checkpoint() restored_path = self.restore_checkpoint()
if not model_restored and (checkpoint_interval and if restored_path:
self.trainer is not None): logging.info("Restored from checkpoint: %s", restored_path)
# If the model is not restored from a checkpoint, and
# `checkpoint_interval` is enabled for training, save an initial
# checkpoint.
self.save_checkpoint()
def train(self, steps: int, checkpoint_at_completion: bool = True): def train(self, steps: int, checkpoint_at_completion: bool = True):
"""Runs training. """Runs training.
......
...@@ -667,9 +667,9 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -667,9 +667,9 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=5) train_steps=10, eval_steps=2, eval_interval=5)
# Expect 3 checkpoints to be saved at step: 0, 5, 10. # Expect 3 checkpoints to be saved at step: 5, 10.
self.assertLen( self.assertLen(
tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 3) tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 2)
# Expect evaluation is performed 2 times at step: 5, 10. # Expect evaluation is performed 2 times at step: 5, 10.
self.assertLen( self.assertLen(
summaries_with_matching_keyword("eval_loss", self.model_dir), 2) summaries_with_matching_keyword("eval_loss", self.model_dir), 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment