Commit 39aea35f authored by Ayush Dubey's avatar Ayush Dubey Committed by A. Unique TensorFlower
Browse files

Add flags to run BERT with MultiWorkerMirroredStrategy.

PiperOrigin-RevId: 264935345
parent b3ee015d
...@@ -32,10 +32,10 @@ def define_common_bert_flags(): ...@@ -32,10 +32,10 @@ def define_common_bert_flags():
'init_checkpoint', None, 'init_checkpoint', None,
'Initial checkpoint (usually from a pre-trained BERT model).') 'Initial checkpoint (usually from a pre-trained BERT model).')
flags.DEFINE_enum( flags.DEFINE_enum(
'strategy_type', 'mirror', ['tpu', 'mirror'], 'strategy_type', 'mirror', ['tpu', 'mirror', 'multi_worker_mirror'],
'Distribution Strategy type to use for training. `tpu` uses ' 'Distribution Strategy type to use for training. `tpu` uses '
'TPUStrategy for running on TPUs, `mirror` uses GPUs with ' 'TPUStrategy for running on TPUs, `mirror` uses GPUs with single host, '
'single host.') '`multi_worker_mirror` uses CPUs or GPUs with multiple hosts.')
flags.DEFINE_integer('num_train_epochs', 3, flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.') 'Total number of training epochs to perform.')
flags.DEFINE_integer( flags.DEFINE_integer(
......
...@@ -321,6 +321,8 @@ def main(_): ...@@ -321,6 +321,8 @@ def main(_):
strategy = None strategy = None
if FLAGS.strategy_type == 'mirror': if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy() strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'multi_worker_mirror':
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
elif FLAGS.strategy_type == 'tpu': elif FLAGS.strategy_type == 'tpu':
# Initialize TPU System. # Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu) cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment