Commit 34db82bf authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 331877703
parent 9e8d7643
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import os import os
import time import time
from typing import Mapping, Any from typing import Any, Mapping, Optional
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -36,30 +36,36 @@ from official.core import train_utils ...@@ -36,30 +36,36 @@ from official.core import train_utils
from official.modeling import performance from official.modeling import performance
from official.modeling.hyperparams import config_definitions from official.modeling.hyperparams import config_definitions
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_integer(
'pretrain_steps',
default=None,
help='The number of total training steps for the pretraining job.')
def run_continuous_finetune( def run_continuous_finetune(
mode: str, mode: str,
params: config_definitions.ExperimentConfig, params: config_definitions.ExperimentConfig,
model_dir: str, model_dir: str,
run_post_eval: bool = False, run_post_eval: bool = False,
pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
"""Run modes with continuous training. """Run modes with continuous training.
Currently only supports continuous_train_and_eval. Currently only supports continuous_train_and_eval.
Args: Args:
mode: A 'str', specifying the mode. mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
continuous_train_and_eval - monitors a checkpoint directory. Once a new checkpoint directory. Once a new checkpoint is discovered, loads the
checkpoint is discovered, loads the checkpoint, finetune the model by checkpoint, finetune the model by training it (probably on another dataset
training it (probably on another dataset or with another task), then or with another task), then evaluate the finetuned model.
evaluate the finetuned model.
params: ExperimentConfig instance. params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries. model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs run_post_eval: Whether to run post eval once after training, metrics logs
are returned. are returned.
pretrain_steps: Optional, the number of total training steps for the
pretraining job.
Returns: Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True, eval logs: returns eval metrics logs when run_post_eval is set to True,
...@@ -140,6 +146,11 @@ def run_continuous_finetune( ...@@ -140,6 +146,11 @@ def run_continuous_finetune(
train_utils.remove_ckpts(model_dir) train_utils.remove_ckpts(model_dir)
if pretrain_steps and global_step.numpy() >= pretrain_steps:
logging.info('The global_step reaches the pretraining end. Continuous '
'finetuning terminates.')
break
if run_post_eval: if run_post_eval:
return eval_metrics return eval_metrics
return {} return {}
...@@ -150,7 +161,7 @@ def main(_): ...@@ -150,7 +161,7 @@ def main(_):
params = train_utils.parse_configuration(FLAGS) params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir model_dir = FLAGS.model_dir
train_utils.serialize_config(params, model_dir) train_utils.serialize_config(params, model_dir)
run_continuous_finetune(FLAGS.mode, params, model_dir) run_continuous_finetune(FLAGS.mode, params, model_dir, FLAGS.pretrain_steps)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -31,10 +31,10 @@ FLAGS = flags.FLAGS ...@@ -31,10 +31,10 @@ FLAGS = flags.FLAGS
tfm_flags.define_flags() tfm_flags.define_flags()
class MainContinuousFinetuneTest(tf.test.TestCase): class ContinuousFinetuneTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(MainContinuousFinetuneTest, self).setUp() super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir') self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
@flagsaver.flagsaver @flagsaver.flagsaver
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment