Commit 252e6384 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Use flags utils and distribution_utils

PiperOrigin-RevId: 281337671
parent 1b8c0ee8
......@@ -209,7 +209,7 @@ script should run with `tf-nightly`.
Just add the following flags to `run_classifier.py` or `run_squad.py`:
```shell
--strategy_type=tpu
--distribution_strategy=tpu
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
......@@ -243,7 +243,7 @@ python run_classifier.py \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--strategy_type=mirror
--distribution_strategy=mirror
```
To use TPU, you only need to switch the distribution strategy type to `tpu` with TPU
......@@ -267,7 +267,7 @@ python run_classifier.py \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--strategy_type=tpu \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
......@@ -299,7 +299,7 @@ python run_squad.py \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--strategy_type=mirror
--distribution_strategy=mirror
```
To use TPU, you need to switch the distribution strategy type to `tpu` with TPU
......@@ -323,7 +323,7 @@ python run_squad.py \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--strategy_type=tpu \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
......
......@@ -22,11 +22,21 @@ from official.utils.flags import core as flags_core
def define_common_bert_flags():
"""Define common flags for BERT tasks."""
flags_core.define_base(
data_dir=False,
model_dir=True,
clean=False,
train_epochs=False,
epochs_between_evals=False,
stop_threshold=False,
batch_size=False,
num_gpu=True,
hooks=False,
export_dir=False,
distribution_strategy=True,
run_eagerly=True)
flags.DEFINE_string('bert_config_file', None,
'Bert configuration file to define core bert layers.')
flags.DEFINE_string('model_dir', None, (
'The directory where the model weights and training/evaluation summaries '
'are stored. If not specified, save to /tmp/bert20/.'))
flags.DEFINE_string(
'model_export_path', None,
'Path to the directory, where trainined model will be '
......@@ -35,11 +45,6 @@ def define_common_bert_flags():
flags.DEFINE_string(
'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_enum(
'strategy_type', 'mirror', ['tpu', 'mirror', 'multi_worker_mirror'],
'Distribution Strategy type to use for training. `tpu` uses '
'TPUStrategy for running on TPUs, `mirror` uses GPUs with single host, '
'`multi_worker_mirror` uses CPUs or GPUs with multiple hosts.')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer(
......@@ -49,9 +54,6 @@ def define_common_bert_flags():
'inside.')
flags.DEFINE_float('learning_rate', 5e-5,
'The initial learning rate for Adam.')
flags.DEFINE_boolean(
'run_eagerly', False,
'Run the model op by op without building a model function.')
flags.DEFINE_boolean(
'scale_loss', False,
'Whether to divide the loss by number of replica inside the per-replica '
......
......@@ -35,8 +35,8 @@ from official.nlp import optimization
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import tpu_lib
flags.DEFINE_enum(
'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
......@@ -350,16 +350,10 @@ def main(_):
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = None
if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'tpu':
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError('The distribution strategy type is not supported: %s' %
FLAGS.strategy_type)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
max_seq_length = input_meta_data['max_seq_length']
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
......
......@@ -30,6 +30,7 @@ from official.nlp import optimization
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.utils.misc import distribution_utils
from official.utils.misc import tpu_lib
flags.DEFINE_string('input_files', None,
......@@ -172,15 +173,10 @@ def main(_):
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
strategy = None
if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'tpu':
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError('The distribution strategy type is not supported: %s' %
FLAGS.strategy_type)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
if strategy:
print('***** Number of cores used : ', strategy.num_replicas_in_sync)
......
......@@ -36,6 +36,7 @@ from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import squad_lib
from official.nlp.bert import tokenization
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import tpu_lib
......@@ -386,17 +387,10 @@ def main(_):
export_squad(FLAGS.model_export_path, input_meta_data)
return
strategy = None
if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'multi_worker_mirror':
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
elif FLAGS.strategy_type == 'tpu':
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError('The distribution strategy type is not supported: %s' %
FLAGS.strategy_type)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'):
train_squad(strategy, input_meta_data)
if FLAGS.mode in ('predict', 'train_and_predict'):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment