Commit 29955e8b authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 334062074
parent f039e4b9
@@ -102,10 +102,24 @@ def run_continuous_finetune(
   summary_writer = tf.summary.create_file_writer(
       os.path.join(model_dir, 'eval'))
 
+  global_step = 0
+
+  def timeout_fn():
+    if pretrain_steps and global_step < pretrain_steps:
+      # Keeps waiting for another timeout period.
+      logging.info(
+          'Continue waiting for new checkpoint as current pretrain '
+          'global_step=%d and target is %d.', global_step, pretrain_steps)
+      return False
+    # Quits the loop.
+    return True
+
   for pretrain_ckpt in tf.train.checkpoints_iterator(
       checkpoint_dir=params.task.init_checkpoint,
       min_interval_secs=10,
-      timeout=params.trainer.continuous_eval_timeout):
+      timeout=params.trainer.continuous_eval_timeout,
+      timeout_fn=timeout_fn):
     with distribution_strategy.scope():
       global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
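For reference, tf.train.checkpoints_iterator calls timeout_fn with no arguments whenever `timeout` seconds elapse without a new checkpoint appearing: returning False resumes the wait, returning True ends the iteration, which is exactly how the timeout_fn above keeps the evaluator alive until pretrain_steps is reached. A minimal self-contained sketch of that contract; the directory and function names below are illustrative, not from this commit:

    import tempfile

    import tensorflow as tf

    ckpt_dir = tempfile.mkdtemp()
    # Write one checkpoint for the iterator to yield.
    tf.train.Checkpoint(step=tf.Variable(0)).save(ckpt_dir + '/ckpt')

    def quit_on_first_timeout():
      # The commit's timeout_fn would return False here while pretraining
      # is still below its target step; this demo always stops.
      return True

    for path in tf.train.checkpoints_iterator(
        ckpt_dir, min_interval_secs=0.1, timeout=1,
        timeout_fn=quit_on_first_timeout):
      print('found checkpoint:', path)  # Runs once; the next wait times out.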
@@ -154,11 +168,6 @@ def run_continuous_finetune(
     # if we need gc here.
     gc.collect()
 
-    if pretrain_steps and global_step.numpy() >= pretrain_steps:
-      logging.info('The global_step reaches the pretraining end. Continuous '
-                   'finetuning terminates.')
-      break
-
   if run_post_eval:
     return eval_metrics
   return {}
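A plausible reading of these two hunks taken together (the commit message says only "Internal change"): the deleted break could fire only after a new checkpoint had actually arrived, so a timeout during a still-running pretraining job would end the iteration early, while the new timeout_fn both keeps the loop waiting below the target step and stops it once pretrain_steps is reached, making the explicit break redundant. A rough pure-Python stand-in for that exit path (illustrative only, not the tf.train implementation, which sleeps between polls):

    pretrain_steps = 1000
    global_step = 0

    def timeout_fn():
      # Same shape as the timeout_fn added above.
      return not (pretrain_steps and global_step < pretrain_steps)

    def iterate(queued_ckpts):
      # Stand-in for checkpoints_iterator: yield queued checkpoints; on an
      # empty queue (a 'timeout'), let timeout_fn decide whether to stop.
      while queued_ckpts or not timeout_fn():
        if queued_ckpts:
          yield queued_ckpts.pop(0)

    for ckpt in iterate(['ckpt-500', 'ckpt-1000']):
      global_step = int(ckpt.split('-')[1])  # As the finetuning loop would.

    print(global_step)  # 1000 -- iteration ended via timeout_fn, no break.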
@@ -15,10 +15,9 @@
 # ==============================================================================
 import os
 
 # Import libraries
 from absl import flags
 from absl.testing import flagsaver
+from absl.testing import parameterized
 import tensorflow as tf
 
 from official.common import flags as tfm_flags
 from official.core import task_factory
@@ -31,14 +30,14 @@ FLAGS = flags.FLAGS
 tfm_flags.define_flags()
 
 
-class ContinuousFinetuneTest(tf.test.TestCase):
+class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
     self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
 
   @flagsaver.flagsaver
-  def testTrainCtl(self):
+  @parameterized.parameters(None, 1)
+  def testTrainCtl(self, pretrain_steps):
     src_model_dir = self.get_temp_dir()
     flags_dict = dict(
         experiment='mock',
@@ -81,7 +80,11 @@ class ContinuousFinetuneTest(tf.test.TestCase):
     params = train_utils.parse_configuration(FLAGS)
 
     eval_metrics = train_ctl_continuous_finetune.run_continuous_finetune(
-        FLAGS.mode, params, FLAGS.model_dir, run_post_eval=True)
+        FLAGS.mode,
+        params,
+        FLAGS.model_dir,
+        run_post_eval=True,
+        pretrain_steps=pretrain_steps)
     self.assertIn('best_acc', eval_metrics)
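On the test side, absl's @parameterized.parameters turns each argument into a separate test case, so testTrainCtl now runs once with pretrain_steps=None (the target-step logic disabled) and once with pretrain_steps=1. A minimal sketch of the pattern in isolation; the class and method names here are made up:

    from absl.testing import parameterized
    import tensorflow as tf


    class PretrainStepsDemoTest(tf.test.TestCase, parameterized.TestCase):

      @parameterized.parameters(None, 1)
      def testPretrainSteps(self, pretrain_steps):
        # Two generated cases: pretrain_steps=None and pretrain_steps=1.
        self.assertIn(pretrain_steps, (None, 1))


    if __name__ == '__main__':
      tf.test.main()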