Commit 29955e8b authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 334062074
parent f039e4b9
@@ -102,10 +102,24 @@ def run_continuous_finetune(
   summary_writer = tf.summary.create_file_writer(
       os.path.join(model_dir, 'eval'))
 
+  global_step = 0
+
+  def timeout_fn():
+    if pretrain_steps and global_step < pretrain_steps:
+      # Keeps waiting for another timeout period.
+      logging.info(
+          'Continue waiting for new checkpoint as current pretrain '
+          'global_step=%d and target is %d.', global_step, pretrain_steps)
+      return False
+    # Quits the loop.
+    return True
+
   for pretrain_ckpt in tf.train.checkpoints_iterator(
       checkpoint_dir=params.task.init_checkpoint,
       min_interval_secs=10,
-      timeout=params.trainer.continuous_eval_timeout):
+      timeout=params.trainer.continuous_eval_timeout,
+      timeout_fn=timeout_fn):
     with distribution_strategy.scope():
       global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
@@ -154,11 +168,6 @@ def run_continuous_finetune(
     # if we need gc here.
     gc.collect()
 
-    if pretrain_steps and global_step.numpy() >= pretrain_steps:
-      logging.info('The global_step reaches the pretraining end. Continuous '
-                   'finetuning terminates.')
-      break
-
   if run_post_eval:
     return eval_metrics
   return {}
...
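Note on the first file's change: the new timeout_fn hooks into tf.train.checkpoints_iterator, which calls it after each `timeout` period with no new checkpoint; returning False keeps the iterator polling, returning True ends the loop. This replaces the explicit in-loop break that the second hunk deletes. Below is a minimal, self-contained sketch of that contract, separate from this commit; the checkpoint directory, step target, and the way the step counter is read back are illustrative assumptions, not the repository's train_utils helper.

import tensorflow as tf

CKPT_DIR = '/tmp/pretrain_ckpts'  # hypothetical pretraining checkpoint dir
TARGET_STEPS = 1000               # hypothetical pretraining end step
latest_step = 0

def timeout_fn():
  # Called by checkpoints_iterator after `timeout` seconds with no new
  # checkpoint; False means "keep waiting", True means "stop iterating".
  return latest_step >= TARGET_STEPS

for ckpt_path in tf.train.checkpoints_iterator(
    CKPT_DIR, min_interval_secs=10, timeout=60, timeout_fn=timeout_fn):
  # Read the step counter back from the checkpoint (assumes it was saved
  # under the object name `global_step`; adjust to the real saving code).
  step_var = tf.Variable(0, dtype=tf.int64)
  tf.train.Checkpoint(global_step=step_var).restore(ckpt_path).expect_partial()
  latest_step = int(step_var.numpy())
  print('Would fine-tune from', ckpt_path, 'at step', latest_step)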
@@ -15,10 +15,9 @@
 # ==============================================================================
 import os
-
-# Import libraries
 from absl import flags
 from absl.testing import flagsaver
+from absl.testing import parameterized
 import tensorflow as tf
 
 from official.common import flags as tfm_flags
 from official.core import task_factory
@@ -31,14 +30,14 @@ FLAGS = flags.FLAGS
 tfm_flags.define_flags()
 
 
-class ContinuousFinetuneTest(tf.test.TestCase):
+class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
     self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
 
-  @flagsaver.flagsaver
-  def testTrainCtl(self):
+  @parameterized.parameters(None, 1)
+  def testTrainCtl(self, pretrain_steps):
     src_model_dir = self.get_temp_dir()
     flags_dict = dict(
         experiment='mock',
@@ -81,7 +80,11 @@ class ContinuousFinetuneTest(tf.test.TestCase):
     params = train_utils.parse_configuration(FLAGS)
     eval_metrics = train_ctl_continuous_finetune.run_continuous_finetune(
-        FLAGS.mode, params, FLAGS.model_dir, run_post_eval=True)
+        FLAGS.mode,
+        params,
+        FLAGS.model_dir,
+        run_post_eval=True,
+        pretrain_steps=pretrain_steps)
     self.assertIn('best_acc', eval_metrics)
...
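Note on the test change: it follows the standard absl parameterized pattern. The class additionally inherits from parameterized.TestCase, and @parameterized.parameters(None, 1) runs testTrainCtl once per value, passing it as pretrain_steps, so both the unbounded case and the "stop after pretraining" case of run_continuous_finetune are exercised. A standalone sketch of the mechanism with a hypothetical test body, not the repository's test:

from absl.testing import parameterized
import tensorflow as tf


class PretrainStepsParamTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(None, 1)
  def test_runs_once_per_value(self, pretrain_steps):
    # Generates two test cases: pretrain_steps=None and pretrain_steps=1.
    if pretrain_steps is None:
      self.assertIsNone(pretrain_steps)
    else:
      self.assertGreaterEqual(pretrain_steps, 1)


if __name__ == '__main__':
  tf.test.main()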