Commit a296fa9c authored by Le Hou, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 329604859
parent 2b166642
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM common training driver."""
from absl import app
from absl import flags
import gin
from official.core import train_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.modeling import performance
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
def main(_):
  """Parses flags and gin bindings, then runs the requested experiment.

  The single positional argument supplied by `app.run` is ignored.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  experiment_config = train_utils.parse_configuration(FLAGS)
  output_dir = FLAGS.model_dir
  run_mode = FLAGS.mode

  # Only modes that include training serialize the config. Pure eval modes do
  # not output yaml files; otherwise a continuous eval job may race against
  # the train job for writing the same file.
  if 'train' in run_mode:
    train_utils.serialize_config(experiment_config, output_dir)

  runtime = experiment_config.runtime

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  precision_dtype = runtime.mixed_precision_dtype
  if precision_dtype:
    performance.set_mixed_precision_policy(precision_dtype, runtime.loss_scale)

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu)

  # The task is built under the strategy scope.
  with strategy.scope():
    task = task_factory.get_task(experiment_config.task, logging_dir=output_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=run_mode,
      params=experiment_config,
      model_dir=output_dir)
if __name__ == '__main__':
  # Define the TFM command-line flags, then hand control to absl to run main.
  tfm_flags.define_flags()
  app.run(main)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM continuous finetuning+eval training driver."""
import os
import time
from typing import Mapping, Any
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.hyperparams import config_definitions
from official.utils.misc import distribution_utils
FLAGS = flags.FLAGS
def run_continuous_finetune(
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
) -> Mapping[str, Any]:
  """Run modes with continuous training.

  Currently only supports continuous_train_and_eval.

  Args:
    mode: A 'str', specifying the mode.
      continuous_train_and_eval - monitors a checkpoint directory. Once a new
      checkpoint is discovered, loads the checkpoint, finetunes the model by
      training it (probably on another dataset or with another task), then
      evaluates the finetuned model.
    params: ExperimentConfig instance. `params.task.init_checkpoint` must be
      the directory that is monitored for new checkpoints.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to return the eval metrics logs of the last
      checkpoint that was processed.

  Returns:
    eval logs: the eval metrics of the last processed checkpoint when
      run_post_eval is set to True (an empty dict if no checkpoint was
      processed); otherwise, returns {}.

  Raises:
    AssertionError: if `mode` is not 'continuous_train_and_eval'.
    ValueError: if `params.task.init_checkpoint` is still not a directory
      after 60 retries spaced 60 seconds apart.
  """
  assert mode == 'continuous_train_and_eval', (
      'Only continuous_train_and_eval is supported by continuous_finetune. '
      'Got mode: {}'.format(mode))

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                           params.runtime.loss_scale)
  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  retry_times = 0
  while not tf.io.gfile.isdir(params.task.init_checkpoint):
    # Wait for the init_checkpoint directory to be created.
    if retry_times >= 60:
      raise ValueError(
          'ExperimentConfig.task.init_checkpoint must be a directory for '
          'continuous_train_and_eval mode.')
    retry_times += 1
    time.sleep(60)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(model_dir, 'eval'))

  # Last path component of model_dir, used as the prefix of the summary group.
  # basename() is empty when model_dir ends with '/'; in that case take the
  # basename of the parent path so both spellings give the same group name
  # instead of the whole directory path.
  model_dir_name = (
      os.path.basename(model_dir) or
      os.path.basename(os.path.dirname(model_dir)))

  # Stays empty when the checkpoint iterator finishes without yielding any
  # checkpoint, so the final return never reads an unbound name.
  eval_metrics = {}

  for pretrain_ckpt in tf.train.checkpoints_iterator(
      checkpoint_dir=params.task.init_checkpoint,
      min_interval_secs=10,
      timeout=params.trainer.continuous_eval_timeout):
    with distribution_strategy.scope():
      global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)

    # Point the task at exactly this pretrain checkpoint. When a best
    # checkpoint export subdir is configured, suffix it with the pretrain
    # global step so every pretrain checkpoint gets its own subdir name.
    if params.trainer.best_checkpoint_export_subdir:
      best_ckpt_subdir = '{}_{}'.format(
          params.trainer.best_checkpoint_export_subdir, global_step)
      params_replaced = params.replace(
          task={'init_checkpoint': pretrain_ckpt},
          trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
    else:
      params_replaced = params.replace(task={'init_checkpoint': pretrain_ckpt})
    params_replaced.lock()
    logging.info('Running finetuning with params: %s', params_replaced)

    with distribution_strategy.scope():
      task = task_factory.get_task(params_replaced.task, logging_dir=model_dir)

    _, eval_metrics = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='train_and_eval',
        # replace params.task.init_checkpoint to make sure that we load
        # exactly this pretrain checkpoint.
        params=params_replaced,
        model_dir=model_dir,
        run_post_eval=True,
        save_summary=False)
    logging.info('Evaluation finished. Pretrain global_step: %d', global_step)
    train_utils.write_json_summary(model_dir, global_step, eval_metrics)

    # Group the eval metrics under '<model dir name>_<task class name>/'.
    summary_grp = model_dir_name + '_' + task.__class__.__name__
    summaries = {
        summary_grp + '/' + name: value
        for name, value in eval_metrics.items()
    }
    train_utils.write_summary(summary_writer, global_step, summaries)

    train_utils.remove_ckpts(model_dir)

  if run_post_eval:
    return eval_metrics
  return {}
def main(_):
  """Parses the experiment configuration and starts continuous finetuning.

  The single positional argument supplied by `app.run` is ignored.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  experiment_config = train_utils.parse_configuration(FLAGS)
  output_dir = FLAGS.model_dir
  # Serialize the parsed config into the model directory before running.
  train_utils.serialize_config(experiment_config, output_dir)
  run_continuous_finetune(
      mode=FLAGS.mode, params=experiment_config, model_dir=output_dir)
if __name__ == '__main__':
  # Define the TFM command-line flags, then hand control to absl to run main.
  tfm_flags.define_flags()
  app.run(main)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
# Import libraries
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.nlp import train_ctl_continuous_finetune
FLAGS = flags.FLAGS
tfm_flags.define_flags()
class MainContinuousFinetuneTest(tf.test.TestCase):
  """Exercises `run_continuous_finetune` with the 'mock' experiment."""

  def setUp(self):
    super().setUp()
    # Directory that receives the finetuning outputs.
    self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')

  @flagsaver.flagsaver
  def testTrainCtl(self):
    """Trains a source model, then finetunes from its checkpoints."""
    pretrain_dir = self.get_temp_dir()

    optimizer_override = {
        'optimizer': {
            'type': 'sgd'
        },
        'learning_rate': {
            'type': 'constant'
        },
    }
    trainer_override = {
        'continuous_eval_timeout': 1,
        'steps_per_loop': 1,
        'train_steps': 1,
        'validation_steps': 1,
        'best_checkpoint_export_subdir': 'best_ckpt',
        'best_checkpoint_eval_metric': 'acc',
        'optimizer_config': optimizer_override,
    }
    flag_overrides = {
        'experiment': 'mock',
        'mode': 'continuous_train_and_eval',
        'model_dir': self._model_dir,
        'params_override': {
            'task': {
                'init_checkpoint': pretrain_dir,
            },
            'trainer': trainer_override,
        },
    }

    with flagsaver.flagsaver(**flag_overrides):
      # Train and save some checkpoints into the directory that the
      # continuous finetune run monitors.
      pretrain_params = train_utils.parse_configuration(FLAGS)
      strategy = tf.distribute.get_strategy()
      with strategy.scope():
        pretrain_task = task_factory.get_task(
            pretrain_params.task, logging_dir=pretrain_dir)
      train_lib.run_experiment(
          distribution_strategy=strategy,
          task=pretrain_task,
          mode='train',
          params=pretrain_params,
          model_dir=pretrain_dir)

      # Run continuous finetuning against those checkpoints and check that
      # the returned metrics contain the best-checkpoint entry.
      finetune_params = train_utils.parse_configuration(FLAGS)
      metrics = train_ctl_continuous_finetune.run_continuous_finetune(
          FLAGS.mode, finetune_params, FLAGS.model_dir, run_post_eval=True)
      self.assertIn('best_acc', metrics)
if __name__ == '__main__':
  # Run the test cases in this module through the TensorFlow test runner.
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment