Commit 34db82bf authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 331877703
parent 9e8d7643
@@ -17,7 +17,7 @@
 import os
 import time
-from typing import Mapping, Any
+from typing import Any, Mapping, Optional
 from absl import app
 from absl import flags
@@ -36,30 +36,36 @@ from official.core import train_utils
 from official.modeling import performance
 from official.modeling.hyperparams import config_definitions

 FLAGS = flags.FLAGS

+flags.DEFINE_integer(
+    'pretrain_steps',
+    default=None,
+    help='The number of total training steps for the pretraining job.')


 def run_continuous_finetune(
     mode: str,
     params: config_definitions.ExperimentConfig,
     model_dir: str,
     run_post_eval: bool = False,
+    pretrain_steps: Optional[int] = None,
 ) -> Mapping[str, Any]:
   """Run modes with continuous training.

   Currently only supports continuous_train_and_eval.

   Args:
-    mode: A 'str', specifying the mode.
-      continuous_train_and_eval - monitors a checkpoint directory. Once a new
-      checkpoint is discovered, loads the checkpoint, finetune the model by
-      training it (probably on another dataset or with another task), then
-      evaluate the finetuned model.
+    mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
+      checkpoint directory. Once a new checkpoint is discovered, loads the
+      checkpoint, finetunes the model by training it (probably on another
+      dataset or with another task), then evaluates the finetuned model.
     params: ExperimentConfig instance.
     model_dir: A 'str', a path to store model checkpoints and summaries.
     run_post_eval: Whether to run post eval once after training; metrics logs
       are returned.
+    pretrain_steps: Optional, the number of total training steps for the
+      pretraining job.

   Returns:
     eval logs: returns eval metrics logs when run_post_eval is set to True,
@@ -140,6 +146,11 @@ def run_continuous_finetune(
     train_utils.remove_ckpts(model_dir)

+    if pretrain_steps and global_step.numpy() >= pretrain_steps:
+      logging.info('The global_step reaches the pretraining end. Continuous '
+                   'finetuning terminates.')
+      break
+
   if run_post_eval:
     return eval_metrics
   return {}
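
For context, a minimal sketch of how a caller might drive the updated function; it mirrors main() below. The model_dir path, the pretraining budget of 100000 steps, and building params from already-parsed flags are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch; the path and step budget are placeholders.
params = train_utils.parse_configuration(FLAGS)
eval_logs = run_continuous_finetune(
    mode='continuous_train_and_eval',
    params=params,
    model_dir='/tmp/finetune_model_dir',  # hypothetical path
    run_post_eval=True,      # return the final eval metrics mapping
    pretrain_steps=100000)   # hypothetical budget; the monitoring loop
                             # breaks once global_step reaches this value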
@@ -150,7 +161,7 @@ def main(_):
   params = train_utils.parse_configuration(FLAGS)
   model_dir = FLAGS.model_dir
   train_utils.serialize_config(params, model_dir)
-  run_continuous_finetune(FLAGS.mode, params, model_dir)
+  run_continuous_finetune(FLAGS.mode, params, model_dir, FLAGS.pretrain_steps)


 if __name__ == '__main__':
......
@@ -31,10 +31,10 @@ FLAGS = flags.FLAGS
 tfm_flags.define_flags()


-class MainContinuousFinetuneTest(tf.test.TestCase):
+class ContinuousFinetuneTest(tf.test.TestCase):

   def setUp(self):
-    super(MainContinuousFinetuneTest, self).setUp()
+    super().setUp()
     self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')

   @flagsaver.flagsaver
......
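
The renamed test class keeps the flagsaver pattern shown above. As a hedged sketch of how the new flag might be exercised inside that class (this test method is not in the commit; its name and assertions are hypothetical):

  @flagsaver.flagsaver
  def test_pretrain_steps_flag(self):
    # Hypothetical: the new flag defaults to None and accepts an integer,
    # which main() forwards into run_continuous_finetune.
    self.assertIsNone(FLAGS.pretrain_steps)
    FLAGS.pretrain_steps = 1
    self.assertEqual(FLAGS.pretrain_steps, 1)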