Commit 9da3a081 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 357284280
parent 537aaad5
...@@ -145,6 +145,11 @@ def run_continuous_finetune( ...@@ -145,6 +145,11 @@ def run_continuous_finetune(
min_interval_secs=10, min_interval_secs=10,
timeout=params.trainer.continuous_eval_timeout, timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn): timeout_fn=timeout_fn):
# If there are checkpoints, they might be the finetune checkpoint of a
# different pretrained checkpoint. So we just remove all checkpoints.
train_utils.remove_ckpts(model_dir)
with distribution_strategy.scope(): with distribution_strategy.scope():
global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt) global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
# Replaces params.task.init_checkpoint to make sure that we load # Replaces params.task.init_checkpoint to make sure that we load
......
...@@ -90,6 +90,9 @@ class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase): ...@@ -90,6 +90,9 @@ class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
pretrain_steps=pretrain_steps) pretrain_steps=pretrain_steps)
self.assertIn('best_acc', eval_metrics) self.assertIn('best_acc', eval_metrics)
self.assertFalse(
tf.io.gfile.exists(os.path.join(FLAGS.model_dir, 'checkpoint')))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment