"megatron/vscode:/vscode.git/clone" did not exist on "c2a32e12b87e737709f332e1ea8dfbde487ffefd"
Commit 497989e0 authored by Bruce Fontaine, committed by A. Unique TensorFlower

Use experimental_connect_to_cluster API in TPU lib to support training on a slice of a TPU pod.

PiperOrigin-RevId: 270926016
parent a52564cb
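In short: `tpu_lib.tpu_initialize` used to connect the eager client only to the master TPU worker host via `tf.config.experimental_connect_to_host`, so the helpers in this diff pinned input-pipeline ops to that host with `tf.device('/job:worker')` (`get_primary_cpu_task` / `use_remote_tpu`). Switching to `tf.config.experimental_connect_to_cluster` registers every worker host in the TPU cluster, which is what allows training on a slice of a TPU pod (multiple hosts), and the diff drops the manual pinning along with it. A hedged sketch of the before/after connection calls (the TPU name below is a placeholder, not from this repo):

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')  # placeholder name

# Before: only the master TPU worker host was reachable from the eager client.
# tf.config.experimental_connect_to_host(resolver.master())

# After: register all worker hosts in the cluster so a multi-host pod slice
# can be driven from a single client.
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)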
@@ -130,8 +130,7 @@ def run_customized_training_loop(
       after every epoch.
     init_checkpoint: Optional checkpoint to load to `sub_model` returned by
       `model_fn`.
-    use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
-      as an optimization.
+    use_remote_tpu: Ignored, will be removed in the future.
     custom_callbacks: A list of Keras Callbacks objects to run during
       training. More specifically, `on_batch_begin()`, `on_batch_end()`,
       methods are invoked during training.
@@ -146,6 +145,8 @@ def run_customized_training_loop(
       attribute or when required parameters are set to none. (2) eval args are
       not specified correctly. (3) metric_fn must be a callable if specified.
   """
+  # TODO(bfontain): Remove use_remote_tpu once there are no models using it.
+  del use_remote_tpu
   if _sentinel is not None:
     raise ValueError('only call `run_customized_training_loop()` '
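The hunk above keeps `use_remote_tpu` in the signature but immediately discards it, so existing call sites keep working while the argument is phased out. A minimal, hypothetical sketch of that deprecate-but-ignore pattern (the function and names below are illustrative, not from the repo):

import warnings

def run_training(steps, use_remote_tpu=None):
  """Toy stand-in for run_customized_training_loop; `use_remote_tpu` is ignored."""
  if use_remote_tpu is not None:
    warnings.warn('use_remote_tpu is ignored and will be removed.', DeprecationWarning)
  # Accept the old kwarg so existing callers do not break, but never use it.
  del use_remote_tpu
  return steps * 2  # placeholder for the real training work

run_training(10, use_remote_tpu=True)  # old call sites still run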
@@ -188,7 +189,6 @@ def run_customized_training_loop(
   # To reduce unnecessary send/receive input pipeline operation, we place input
   # pipeline ops in worker task.
-  with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
   train_iterator = _get_input_iterator(train_input_fn, strategy)
   with distribution_utils.get_strategy_scope(strategy):
......
@@ -152,7 +152,6 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         eval_steps=10,
         init_checkpoint=None,
         metric_fn=metric_fn,
-        use_remote_tpu=False,
         custom_callbacks=None,
         run_eagerly=run_eagerly)
......
@@ -90,7 +90,6 @@ def run_customized_training(strategy,
                             warmup_steps,
                             initial_lr,
                             init_checkpoint,
-                            use_remote_tpu=False,
                             custom_callbacks=None,
                             run_eagerly=False):
   """Run BERT classifier training using low-level API."""
@@ -151,7 +150,6 @@ def run_customized_training(strategy,
       eval_steps=eval_steps,
       init_checkpoint=init_checkpoint,
       metric_fn=metric_fn,
-      use_remote_tpu=use_remote_tpu,
       custom_callbacks=custom_callbacks,
       run_eagerly=run_eagerly)
@@ -201,7 +199,6 @@ def run_bert(strategy, input_meta_data):
   # Runs customized training loop.
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   trained_model = run_customized_training(
       strategy,
       bert_config,
@@ -214,11 +211,9 @@ def run_bert(strategy, input_meta_data):
       warmup_steps,
       FLAGS.learning_rate,
       FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu,
       run_eagerly=FLAGS.run_eagerly)
   if FLAGS.model_export_path:
-    with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
     model_saving_utils.export_bert_model(
         FLAGS.model_export_path, model=trained_model)
   return trained_model
@@ -238,7 +233,6 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
......
@@ -114,8 +114,7 @@ def run_customized_training(strategy,
                             initial_lr,
                             warmup_steps,
                             input_files,
-                            train_batch_size,
-                            use_remote_tpu=False):
+                            train_batch_size):
   """Run BERT pretrain model training using low-level API."""
   train_input_fn = functools.partial(get_pretrain_input_data, input_files,
@@ -148,8 +147,7 @@ def run_customized_training(strategy,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
-      epochs=epochs,
-      use_remote_tpu=use_remote_tpu)
+      epochs=epochs)
   # Creates the BERT core model outside distribution strategy scope.
   _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
@@ -173,7 +171,6 @@ def run_bert_pretrain(strategy):
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   return run_customized_training(
       strategy,
       bert_config,
@@ -186,8 +183,7 @@ def run_bert_pretrain(strategy):
       FLAGS.learning_rate,
       FLAGS.warmup_steps,
       FLAGS.input_files,
-      FLAGS.train_batch_size,
-      use_remote_tpu=use_remote_tpu)
+      FLAGS.train_batch_size)
 def main(_):
@@ -200,7 +196,6 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
......
@@ -245,7 +245,6 @@ def train_squad(strategy,
   loss_fn = get_loss_fn(
       loss_factor=1.0 /
       strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   model_training_utils.run_customized_training_loop(
       strategy=strategy,
@@ -257,7 +256,6 @@ def train_squad(strategy,
       epochs=epochs,
       train_input_fn=train_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu,
       run_eagerly=run_eagerly,
       custom_callbacks=custom_callbacks)
@@ -366,7 +364,6 @@ def main(_):
   elif FLAGS.strategy_type == 'multi_worker_mirror':
     strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
......
@@ -126,23 +126,13 @@ def get_metric_fn():
   return train_acc_metric
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
   else:
     raise ValueError("The distribution strategy type is not supported: %s" %
                      FLAGS.strategy_type)
@@ -180,7 +170,6 @@ def main(unused_argv):
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["n_class"] = FLAGS.n_class
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
   training_utils.train(
       strategy=strategy,
       model_fn=model_fn,
......
@@ -52,24 +52,14 @@ def get_pretrainxlnet_model(model_config, run_config):
   return model
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   num_hosts = 1
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
     topology = FLAGS.tpu_topology.split("x")
     total_num_core = 2 * int(topology[0]) * int(topology[1])
     num_hosts = total_num_core // FLAGS.num_core_per_host
@@ -111,7 +101,6 @@ def main(unused_argv):
   model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                run_config)
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
   training_utils.train(
       strategy=strategy,
       model_fn=model_fn,
......
@@ -91,13 +91,6 @@ class InputFeatures(object):
     self.is_impossible = is_impossible
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
 # pylint: disable=unused-argument
 def run_evaluation(strategy,
                    test_input_fn,
@@ -224,14 +217,11 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
   else:
     raise ValueError("The distribution strategy type is not supported: %s" %
                      FLAGS.strategy_type)
@@ -285,7 +275,6 @@ def main(unused_argv):
   eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                               eval_steps, input_meta_data)
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
   training_utils.train(
       strategy=strategy,
       model_fn=model_fn,
......
@@ -43,7 +43,6 @@ from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
 from official.utils.flags import core as flags_core
-from official.utils.misc import tpu_lib
 FLAGS = flags.FLAGS
@@ -254,10 +253,6 @@ def run_ncf(_):
         "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
     callbacks.append(early_stopping_callback)
-  use_remote_tpu = params["use_tpu"] and FLAGS.tpu
-  primary_cpu_task = tpu_lib.get_primary_cpu_task(use_remote_tpu)
-  with tf.device(primary_cpu_task):
   (train_input_dataset, eval_input_dataset,
    num_train_steps, num_eval_steps) = \
       (ncf_input_pipeline.create_ncf_input_data(
......
@@ -128,7 +128,6 @@ def get_distribution_strategy(distribution_strategy="default",
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
     return tf.distribute.experimental.TPUStrategy(cluster_resolver)
......
@@ -21,18 +21,14 @@ def tpu_initialize(tpu_address):
   """Initializes TPU for TF 2.0 training.
   Args:
-    tpu_address: string, bns address of TPU workers.
+    tpu_address: string, bns address of master TPU worker.
   Returns:
     A TPUClusterResolver.
   """
   cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
       tpu=tpu_address)
-  tf.config.experimental_connect_to_host(cluster_resolver.master())
+  if tpu_address not in ('', 'local'):
+    tf.config.experimental_connect_to_cluster(cluster_resolver)
   tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
   return cluster_resolver
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns remote TPU worker address. No-op for GPU/CPU training."""
-  return "/job:worker" if use_remote_tpu else ""