"vscode:/vscode.git/clone" did not exist on "2526c614bd156a65bdf326366b57358d0873a781"
Commit 7c29567d authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Remove the logic of saving checkpoint at step 0 in Orbit.

PiperOrigin-RevId: 326116282
parent 67d39836
......@@ -134,13 +134,9 @@ class Controller:
# TODO(momernick): We probably only want to do this on certain occasions?
if self.checkpoint_manager is not None:
checkpoint_interval = self.checkpoint_manager.checkpoint_interval
model_restored = self.restore_checkpoint()
if not model_restored and (checkpoint_interval and
self.trainer is not None):
# If the model is not restored from a checkpoint, and
# `checkpoint_interval` is enabled for training, save an initial
# checkpoint.
self.save_checkpoint()
restored_path = self.restore_checkpoint()
if restored_path:
logging.info("Restored from checkpoint: %s", restored_path)
def train(self, steps: int, checkpoint_at_completion: bool = True):
"""Runs training.
......
......@@ -667,9 +667,9 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=5)
# Expect 3 checkpoints to be saved at step: 0, 5, 10.
# Expect 3 checkpoints to be saved at step: 5, 10.
self.assertLen(
tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 3)
tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 2)
# Expect evaluation is performed 2 times at step: 5, 10.
self.assertLen(
summaries_with_matching_keyword("eval_loss", self.model_dir), 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment