Commit 497989e0 authored by Bruce Fontaine, committed by A. Unique TensorFlower

Use experimental_connect_to_cluster API in TPU lib to support training on a slice of a TPU pod.

PiperOrigin-RevId: 270926016
parent a52564cb
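
The substance of the change is in tpu_lib.tpu_initialize (see the last diff below): instead of calling tf.config.experimental_connect_to_host with only the master worker address, the client now hands the resolver to tf.config.experimental_connect_to_cluster, which registers every host in the cluster spec and therefore also works when the resolver points at a multi-host slice of a TPU pod. A minimal sketch of the new initialization path, mirroring the diff (nothing here goes beyond what the commit itself changes):

import tensorflow as tf

def tpu_initialize(tpu_address):
  """Connects the client to the TPU cluster and initializes the TPU system."""
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=tpu_address)
  # Register all worker hosts from the cluster spec, not just the master host;
  # this is what enables training on a slice of a TPU pod.
  if tpu_address not in ('', 'local'):
    tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  return cluster_resolver
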
@@ -130,8 +130,7 @@ def run_customized_training_loop(
       after every epoch.
     init_checkpoint: Optional checkpoint to load to `sub_model` returned by
       `model_fn`.
-    use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
-      as an optimization.
+    use_remote_tpu: Ignored, will be removed in the future.
     custom_callbacks: A list of Keras Callbacks objects to run during
       training. More specifically, `on_batch_begin()`, `on_batch_end()`,
       methods are invoked during training.
@@ -146,6 +145,8 @@ def run_customized_training_loop(
       attribute or when required parameters are set to none. (2) eval args are
       not specified correctly. (3) metric_fn must be a callable if specified.
   """
+  # TODO(bfontain): Remove use_remote_tpu once there are no models using it.
+  del use_remote_tpu
   if _sentinel is not None:
     raise ValueError('only call `run_customized_training_loop()` '
@@ -188,7 +189,6 @@ def run_customized_training_loop(
   # To reduce unnecessary send/receive input pipeline operation, we place input
   # pipeline ops in worker task.
-  with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
   train_iterator = _get_input_iterator(train_input_fn, strategy)
   with distribution_utils.get_strategy_scope(strategy):
...
@@ -152,7 +152,6 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         eval_steps=10,
         init_checkpoint=None,
         metric_fn=metric_fn,
-        use_remote_tpu=False,
         custom_callbacks=None,
         run_eagerly=run_eagerly)
...
@@ -90,7 +90,6 @@ def run_customized_training(strategy,
                             warmup_steps,
                             initial_lr,
                             init_checkpoint,
-                            use_remote_tpu=False,
                             custom_callbacks=None,
                             run_eagerly=False):
   """Run BERT classifier training using low-level API."""
@@ -151,7 +150,6 @@ def run_customized_training(strategy,
       eval_steps=eval_steps,
       init_checkpoint=init_checkpoint,
       metric_fn=metric_fn,
-      use_remote_tpu=use_remote_tpu,
       custom_callbacks=custom_callbacks,
       run_eagerly=run_eagerly)
@@ -201,7 +199,6 @@ def run_bert(strategy, input_meta_data):
   # Runs customized training loop.
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   trained_model = run_customized_training(
       strategy,
       bert_config,
@@ -214,11 +211,9 @@ def run_bert(strategy, input_meta_data):
       warmup_steps,
       FLAGS.learning_rate,
       FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu,
       run_eagerly=FLAGS.run_eagerly)
   if FLAGS.model_export_path:
-    with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
     model_saving_utils.export_bert_model(
         FLAGS.model_export_path, model=trained_model)
   return trained_model
@@ -238,7 +233,6 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...
@@ -114,8 +114,7 @@ def run_customized_training(strategy,
                             initial_lr,
                             warmup_steps,
                             input_files,
-                            train_batch_size,
-                            use_remote_tpu=False):
+                            train_batch_size):
   """Run BERT pretrain model training using low-level API."""
   train_input_fn = functools.partial(get_pretrain_input_data, input_files,
@@ -148,8 +147,7 @@ def run_customized_training(strategy,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
-      epochs=epochs,
-      use_remote_tpu=use_remote_tpu)
+      epochs=epochs)
   # Creates the BERT core model outside distribution strategy scope.
   _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
@@ -173,7 +171,6 @@ def run_bert_pretrain(strategy):
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   return run_customized_training(
       strategy,
       bert_config,
@@ -186,8 +183,7 @@ def run_bert_pretrain(strategy):
       FLAGS.learning_rate,
       FLAGS.warmup_steps,
       FLAGS.input_files,
-      FLAGS.train_batch_size,
-      use_remote_tpu=use_remote_tpu)
+      FLAGS.train_batch_size)
 def main(_):
@@ -200,7 +196,6 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...
@@ -245,7 +245,6 @@ def train_squad(strategy,
   loss_fn = get_loss_fn(
       loss_factor=1.0 /
       strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   model_training_utils.run_customized_training_loop(
       strategy=strategy,
@@ -257,7 +256,6 @@ def train_squad(strategy,
       epochs=epochs,
       train_input_fn=train_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu,
       run_eagerly=run_eagerly,
       custom_callbacks=custom_callbacks)
@@ -366,7 +364,6 @@ def main(_):
   elif FLAGS.strategy_type == 'multi_worker_mirror':
     strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...
@@ -126,23 +126,13 @@ def get_metric_fn():
   return train_acc_metric
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
   else:
     raise ValueError("The distribution strategy type is not supported: %s" %
                      FLAGS.strategy_type)
@@ -180,7 +170,6 @@ def main(unused_argv):
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["n_class"] = FLAGS.n_class
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
   training_utils.train(
       strategy=strategy,
       model_fn=model_fn,
...
@@ -52,24 +52,14 @@ def get_pretrainxlnet_model(model_config, run_config):
   return model
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   num_hosts = 1
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
     topology = FLAGS.tpu_topology.split("x")
     total_num_core = 2 * int(topology[0]) * int(topology[1])
     num_hosts = total_num_core // FLAGS.num_core_per_host
@@ -111,7 +101,6 @@ def main(unused_argv):
   model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                run_config)
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
   training_utils.train(
       strategy=strategy,
       model_fn=model_fn,
...
@@ -91,13 +91,6 @@ class InputFeatures(object):
     self.is_impossible = is_impossible
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
 # pylint: disable=unused-argument
 def run_evaluation(strategy,
                    test_input_fn,
@@ -224,14 +217,11 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
   else:
     raise ValueError("The distribution strategy type is not supported: %s" %
                      FLAGS.strategy_type)
@@ -285,7 +275,6 @@ def main(unused_argv):
   eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                               eval_steps, input_meta_data)
-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
   training_utils.train(
       strategy=strategy,
       model_fn=model_fn,
...
@@ -43,7 +43,6 @@ from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
 from official.utils.flags import core as flags_core
-from official.utils.misc import tpu_lib
 FLAGS = flags.FLAGS
@@ -254,10 +253,6 @@ def run_ncf(_):
         "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
     callbacks.append(early_stopping_callback)
-  use_remote_tpu = params["use_tpu"] and FLAGS.tpu
-  primary_cpu_task = tpu_lib.get_primary_cpu_task(use_remote_tpu)
-  with tf.device(primary_cpu_task):
   (train_input_dataset, eval_input_dataset,
    num_train_steps, num_eval_steps) = \
       (ncf_input_pipeline.create_ncf_input_data(
...
@@ -128,7 +128,6 @@ def get_distribution_strategy(distribution_strategy="default",
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
     return tf.distribute.experimental.TPUStrategy(cluster_resolver)
...
@@ -21,18 +21,14 @@ def tpu_initialize(tpu_address):
   """Initializes TPU for TF 2.0 training.
   Args:
-    tpu_address: string, bns address of TPU workers.
+    tpu_address: string, bns address of master TPU worker.
   Returns:
     A TPUClusterResolver.
   """
   cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
       tpu=tpu_address)
-  tf.config.experimental_connect_to_host(cluster_resolver.master())
+  if tpu_address not in ('', 'local'):
+    tf.config.experimental_connect_to_cluster(cluster_resolver)
   tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
   return cluster_resolver
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns remote TPU worker address. No-op for GPU/CPU training."""
-  return "/job:worker" if use_remote_tpu else ""