Commit 649a5944 authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 388586684
parent 0bd9a8b2
@@ -27,9 +27,15 @@ from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
+from official.nlp import continuous_finetune_lib

 FLAGS = flags.FLAGS

+flags.DEFINE_integer(
+    'pretrain_steps',
+    default=None,
+    help='The number of total training steps for the pretraining job.')
+

 def main(_):
   gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
@@ -40,27 +46,33 @@ def main(_):
     # may race against the train job for writing the same file.
     train_utils.serialize_config(params, model_dir)

-  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
-  # can have significant impact on model speeds by utilizing float16 in case of
-  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
-  # dtype is float16
-  if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
-  distribution_strategy = distribute_utils.get_distribution_strategy(
-      distribution_strategy=params.runtime.distribution_strategy,
-      all_reduce_alg=params.runtime.all_reduce_alg,
-      num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu,
-      **params.runtime.model_parallelism())
-  with distribution_strategy.scope():
-    task = task_factory.get_task(params.task, logging_dir=model_dir)
-
-  train_lib.run_experiment(
-      distribution_strategy=distribution_strategy,
-      task=task,
-      mode=FLAGS.mode,
-      params=params,
-      model_dir=model_dir)
+  if FLAGS.mode == 'continuous_train_and_eval':
+    continuous_finetune_lib.run_continuous_finetune(
+        FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
+
+  else:
+    # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
+    # can have significant impact on model speeds by utilizing float16 in case
+    # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
+    # when dtype is float16
+    if params.runtime.mixed_precision_dtype:
+      performance.set_mixed_precision_policy(
+          params.runtime.mixed_precision_dtype)
+    distribution_strategy = distribute_utils.get_distribution_strategy(
+        distribution_strategy=params.runtime.distribution_strategy,
+        all_reduce_alg=params.runtime.all_reduce_alg,
+        num_gpus=params.runtime.num_gpus,
+        tpu_address=params.runtime.tpu,
+        **params.runtime.model_parallelism())
+    with distribution_strategy.scope():
+      task = task_factory.get_task(params.task, logging_dir=model_dir)
+
+    train_lib.run_experiment(
+        distribution_strategy=distribution_strategy,
+        task=task,
+        mode=FLAGS.mode,
+        params=params,
+        model_dir=model_dir)

   train_utils.save_gin_config(FLAGS.mode, model_dir)
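With this change, when --mode=continuous_train_and_eval the driver delegates to continuous_finetune_lib.run_continuous_finetune instead of building a distribution strategy and calling train_lib.run_experiment; every other mode keeps the previous path. The snippet below is only a tiny standalone demonstration of the flag mechanics the diff adds (an illustration written here, not part of the commit): absl parses --pretrain_steps from argv and exposes it as FLAGS.pretrain_steps, defaulting to None.

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer(
    'pretrain_steps',
    default=None,
    help='The number of total training steps for the pretraining job.')


def main(_):
  # With no --pretrain_steps on the command line the value stays None;
  # otherwise absl converts the argument to an int before main runs.
  if FLAGS.pretrain_steps is None:
    print('pretrain_steps not set')
  else:
    print('pretraining job is expected to run for', FLAGS.pretrain_steps,
          'steps')


if __name__ == '__main__':
  app.run(main)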
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM continuous finetuning+eval training driver."""
from absl import app
from absl import flags
import gin
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import train_utils
from official.nlp import continuous_finetune_lib
FLAGS = flags.FLAGS
flags.DEFINE_integer(
'pretrain_steps',
default=None,
help='The number of total training steps for the pretraining job.')
def main(_):
# TODO(b/177863554): consolidate to nlp/train.py
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
train_utils.serialize_config(params, model_dir)
continuous_finetune_lib.run_continuous_finetune(
FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
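For context on what the driver above hands off to: the sketch below illustrates the general continuous-finetuning pattern that the --pretrain_steps flag supports, assuming a separate pretraining job keeps writing checkpoints into a directory that the finetuning job watches. It is a simplified illustration under assumptions, not the contents of continuous_finetune_lib; the helper stubs and the use of tf.train.checkpoints_iterator are choices made for this sketch only.

import tensorflow as tf


def _global_step_of(checkpoint_path):
  # Stub: recover the pretraining step from the checkpoint prefix, e.g.
  # ".../ckpt-12000" -> 12000. A real implementation might instead restore a
  # global_step variable from the checkpoint.
  try:
    return int(checkpoint_path.rsplit('-', 1)[-1])
  except ValueError:
    return 0


def _finetune_and_evaluate(init_checkpoint):
  # Stub: stands in for one finetuning+evaluation run initialized from the
  # given pretraining checkpoint (e.g. a train_lib.run_experiment call).
  print('would finetune from', init_checkpoint)


def run_continuous_finetune_sketch(pretrain_model_dir,
                                   pretrain_steps=None,
                                   timeout_secs=3600):
  """Fine-tunes from each new pretraining checkpoint until pretrain_steps."""
  for ckpt in tf.train.checkpoints_iterator(
      pretrain_model_dir, timeout=timeout_secs):
    _finetune_and_evaluate(init_checkpoint=ckpt)
    if pretrain_steps is not None and _global_step_of(ckpt) >= pretrain_steps:
      break  # The pretraining job is done; no more checkpoints will appear.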