Commit 497989e0 authored by Bruce Fontaine, committed by A. Unique TensorFlower

Use experimental_connect_to_cluster API in TPU lib to support training on a slice of a TPU pod.

PiperOrigin-RevId: 270926016
parent a52564cb
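For context, a minimal sketch of how a caller is expected to obtain a TPU strategy for a pod slice after this change; it is not part of the commit. `make_tpu_strategy` and the placeholder TPU name are hypothetical, while `tpu_lib.tpu_initialize`, its import path, and `TPUStrategy` are taken from the diff below.

    import tensorflow as tf
    from official.utils.misc import tpu_lib  # helper module updated in this commit

    def make_tpu_strategy(tpu_address):
      # tpu_initialize() now calls tf.config.experimental_connect_to_cluster(),
      # which registers every TPU worker host in the slice with the eager
      # runtime, so input-pipeline ops can run on the worker hosts without the
      # old tf.device(get_primary_cpu_task(...)) placement.
      cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
      return tf.distribute.experimental.TPUStrategy(cluster_resolver)

    strategy = make_tpu_strategy('my-tpu-name')  # the scripts below pass FLAGS.tpu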
@@ -130,8 +130,7 @@ def run_customized_training_loop(
       after every epoch.
     init_checkpoint: Optional checkpoint to load to `sub_model` returned by
       `model_fn`.
-    use_remote_tpu: If true, input pipeline ops are placed in TPU worker host
-      as an optimization.
+    use_remote_tpu: Ignored, will be removed in the future.
     custom_callbacks: A list of Keras Callbacks objects to run during
       training. More specifically, `on_batch_begin()`, `on_batch_end()`,
       methods are invoked during training.
@@ -146,6 +145,8 @@ def run_customized_training_loop(
       attribute or when required parameters are set to none. (2) eval args are
       not specified correctly. (3) metric_fn must be a callable if specified.
   """
+  # TODO(bfontain): Remove use_remote_tpu once there are no models using it.
+  del use_remote_tpu
   if _sentinel is not None:
     raise ValueError('only call `run_customized_training_loop()` '
@@ -188,233 +189,232 @@ def run_customized_training_loop(
   # To reduce unnecessary send/receive input pipeline operation, we place input
   # pipeline ops in worker task.
-  with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
-    [entire training-loop body below, unchanged except that it was nested one
-     indentation level deeper under the removed tf.device scope]
+  train_iterator = _get_input_iterator(train_input_fn, strategy)
+
+  with distribution_utils.get_strategy_scope(strategy):
+    # To correctly place the model weights on accelerators,
+    # model and optimizer should be created in scope.
+    model, sub_model = model_fn()
+    if not hasattr(model, 'optimizer'):
+      raise ValueError('User should set optimizer attribute to model '
+                       'inside `model_fn`.')
+    optimizer = model.optimizer
+    use_float16 = isinstance(
+        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
+
+    if init_checkpoint:
+      logging.info(
+          'Checkpoint file %s found and restoring from '
+          'initial checkpoint for core model.', init_checkpoint)
+      checkpoint = tf.train.Checkpoint(model=sub_model)
+      checkpoint.restore(init_checkpoint).assert_consumed()
+      logging.info('Loading from checkpoint file completed')
+
+    train_loss_metric = tf.keras.metrics.Mean(
+        'training_loss', dtype=tf.float32)
+    eval_metrics = [metric_fn()] if metric_fn else []
+    # If evaluation is required, make a copy of metric as it will be used by
+    # both train and evaluation.
+    train_metrics = [
+        metric.__class__.from_config(metric.get_config())
+        for metric in eval_metrics
+    ]
+
+    # Create summary writers
+    eval_summary_writer = tf.summary.create_file_writer(
+        os.path.join(model_dir, 'summaries/eval'))
+    if steps_per_loop >= _MIN_SUMMARY_STEPS:
+      # Only writes summary when the stats are collected sufficiently over
+      # enough steps.
+      train_summary_writer = tf.summary.create_file_writer(
+          os.path.join(model_dir, 'summaries/train'))
+    else:
+      train_summary_writer = None
+
+    # Collects training variables.
+    training_vars = model.trainable_variables
+
+    def _replicated_step(inputs):
+      """Replicated training step."""
+
+      inputs, labels = inputs
+      with tf.GradientTape() as tape:
+        model_outputs = model(inputs, training=True)
+        loss = loss_fn(labels, model_outputs)
+        if use_float16:
+          scaled_loss = optimizer.get_scaled_loss(loss)
+
+      if use_float16:
+        scaled_grads = tape.gradient(scaled_loss, training_vars)
+        grads = optimizer.get_unscaled_gradients(scaled_grads)
+      else:
+        grads = tape.gradient(loss, training_vars)
+      optimizer.apply_gradients(zip(grads, training_vars))
+      # For reporting, the metric takes the mean of losses.
+      train_loss_metric.update_state(loss)
+      for metric in train_metrics:
+        metric.update_state(labels, model_outputs)
+
+    @tf.function
+    def train_steps(iterator, steps):
+      """Performs distributed training steps in a loop.
+
+      Args:
+        iterator: the distributed iterator of training datasets.
+        steps: an tf.int32 integer tensor to specify number of steps to run
+          inside host training loop.
+
+      Raises:
+        ValueError: Any of the arguments or tensor shapes are invalid.
+      """
+      if not isinstance(steps, tf.Tensor):
+        raise ValueError('steps should be an Tensor. Python object may cause '
+                         'retracing.')
+
+      for _ in tf.range(steps):
+        strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+
+    def train_single_step(iterator):
+      """Performs a distributed training step.
+
+      Args:
+        iterator: the distributed iterator of training datasets.
+
+      Raises:
+        ValueError: Any of the arguments or tensor shapes are invalid.
+      """
+      strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+
+    def test_step(iterator):
+      """Calculates evaluation metrics on distributed devices."""
+
+      def _test_step_fn(inputs):
+        """Replicated accuracy calculation."""
+
+        inputs, labels = inputs
+        model_outputs = model(inputs, training=False)
+        for metric in eval_metrics:
+          metric.update_state(labels, model_outputs)
+
+      strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
+
+    if not run_eagerly:
+      train_single_step = tf.function(train_single_step)
+      test_step = tf.function(test_step)
+
+    def _run_evaluation(current_training_step, test_iterator):
+      """Runs validation steps and aggregate metrics."""
+      for _ in range(eval_steps):
+        test_step(test_iterator)
+
+      with eval_summary_writer.as_default():
+        for metric in eval_metrics + model.metrics:
+          metric_value = _float_metric_value(metric)
+          logging.info('Step: [%d] Validation %s = %f', current_training_step,
+                       metric.name, metric_value)
+          tf.summary.scalar(
+              metric.name, metric_value, step=current_training_step)
+        eval_summary_writer.flush()
+
+    def _run_callbacks_on_batch_begin(batch):
+      """Runs custom callbacks at the start of every step."""
+      if not custom_callbacks:
+        return
+      for callback in custom_callbacks:
+        callback.on_batch_begin(batch)
+
+    def _run_callbacks_on_batch_end(batch):
+      """Runs custom callbacks at the end of every step."""
+      if not custom_callbacks:
+        return
+      for callback in custom_callbacks:
+        callback.on_batch_end(batch)
+
+    # Training loop starts here.
+    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
+    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
+    if latest_checkpoint_file:
+      logging.info(
+          'Checkpoint file %s found and restoring from '
+          'checkpoint', latest_checkpoint_file)
+      checkpoint.restore(latest_checkpoint_file)
+      logging.info('Loading from checkpoint file completed')
+
+    current_step = optimizer.iterations.numpy()
+    checkpoint_name = 'ctl_step_{step}.ckpt'
+
+    while current_step < total_training_steps:
+      # Training loss/metric are taking average over steps inside micro
+      # training loop. We reset the their values before each round.
+      train_loss_metric.reset_states()
+      for metric in train_metrics + model.metrics:
+        metric.reset_states()
+
+      _run_callbacks_on_batch_begin(current_step)
+      # Runs several steps in the host while loop.
+      steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
+
+      if steps == 1:
+        # TODO(zongweiz): merge with train_steps once tf.while_loop
+        # GPU performance bugs are fixed.
+        train_single_step(train_iterator)
+      else:
+        # Converts steps to a Tensor to avoid tf.function retracing.
+        train_steps(train_iterator,
+                    tf.convert_to_tensor(steps, dtype=tf.int32))
+      _run_callbacks_on_batch_end(current_step)
+      current_step += steps
+
+      train_loss = _float_metric_value(train_loss_metric)
+      # Updates training logging.
+      training_status = 'Train Step: %d/%d / loss = %s' % (
+          current_step, total_training_steps, train_loss)
+
+      if train_summary_writer:
+        with train_summary_writer.as_default():
+          tf.summary.scalar(
+              train_loss_metric.name, train_loss, step=current_step)
+          for metric in train_metrics + model.metrics:
+            metric_value = _float_metric_value(metric)
+            training_status += ' %s = %f' % (metric.name, metric_value)
+            tf.summary.scalar(metric.name, metric_value, step=current_step)
+          train_summary_writer.flush()
+      logging.info(training_status)
+
+      # Saves model checkpoints and run validation steps at every epoch end.
+      if current_step % steps_per_epoch == 0:
+        # To avoid repeated model saving, we do not save after the last
+        # step of training.
+        if current_step < total_training_steps:
+          _save_checkpoint(checkpoint, model_dir,
+                           checkpoint_name.format(step=current_step))
+
+        if eval_input_fn:
+          logging.info('Running evaluation after step: %s.', current_step)
+          _run_evaluation(current_step,
+                          _get_input_iterator(eval_input_fn, strategy))
+          # Re-initialize evaluation metric.
+          for metric in eval_metrics + model.metrics:
+            metric.reset_states()
+
+    _save_checkpoint(checkpoint, model_dir,
+                     checkpoint_name.format(step=current_step))
+
+    if eval_input_fn:
+      logging.info('Running final evaluation after training is complete.')
+      _run_evaluation(current_step,
+                      _get_input_iterator(eval_input_fn, strategy))
+
+    training_summary = {
+        'total_training_steps': total_training_steps,
+        'train_loss': _float_metric_value(train_loss_metric),
+    }
+    if eval_metrics:
+      # TODO(hongkuny): Cleans up summary reporting in text.
+      training_summary['last_train_metrics'] = _float_metric_value(
+          train_metrics[0])
+      training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
+
+    write_txt_summary(training_summary, model_dir)
+
+    return model
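Note on the hunks above: `use_remote_tpu` stays in the signature but is deleted immediately (`del use_remote_tpu`), so call sites that still pass the argument keep working while they are migrated; the TODO tracks removing the parameter once no models use it.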
@@ -152,7 +152,6 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         eval_steps=10,
         init_checkpoint=None,
         metric_fn=metric_fn,
-        use_remote_tpu=False,
         custom_callbacks=None,
         run_eagerly=run_eagerly)
...
@@ -90,7 +90,6 @@ def run_customized_training(strategy,
                             warmup_steps,
                             initial_lr,
                             init_checkpoint,
-                            use_remote_tpu=False,
                             custom_callbacks=None,
                             run_eagerly=False):
   """Run BERT classifier training using low-level API."""
@@ -151,7 +150,6 @@ def run_customized_training(strategy,
       eval_steps=eval_steps,
       init_checkpoint=init_checkpoint,
       metric_fn=metric_fn,
-      use_remote_tpu=use_remote_tpu,
       custom_callbacks=custom_callbacks,
       run_eagerly=run_eagerly)
@@ -201,7 +199,6 @@ def run_bert(strategy, input_meta_data):
   # Runs customized training loop.
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   trained_model = run_customized_training(
       strategy,
       bert_config,
@@ -214,13 +211,11 @@ def run_bert(strategy, input_meta_data):
       warmup_steps,
       FLAGS.learning_rate,
       FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu,
       run_eagerly=FLAGS.run_eagerly)

   if FLAGS.model_export_path:
-    with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
-      model_saving_utils.export_bert_model(
-          FLAGS.model_export_path, model=trained_model)
+    model_saving_utils.export_bert_model(
+        FLAGS.model_export_path, model=trained_model)
   return trained_model
@@ -238,7 +233,6 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...
@@ -114,8 +114,7 @@ def run_customized_training(strategy,
                             initial_lr,
                             warmup_steps,
                             input_files,
-                            train_batch_size,
-                            use_remote_tpu=False):
+                            train_batch_size):
   """Run BERT pretrain model training using low-level API."""

   train_input_fn = functools.partial(get_pretrain_input_data, input_files,
@@ -148,8 +147,7 @@ def run_customized_training(strategy,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
-      epochs=epochs,
-      use_remote_tpu=use_remote_tpu)
+      epochs=epochs)

   # Creates the BERT core model outside distribution strategy scope.
   _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
@@ -173,7 +171,6 @@ def run_bert_pretrain(strategy):
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
   return run_customized_training(
       strategy,
       bert_config,
@@ -186,8 +183,7 @@ def run_bert_pretrain(strategy):
       FLAGS.learning_rate,
       FLAGS.warmup_steps,
       FLAGS.input_files,
-      FLAGS.train_batch_size,
-      use_remote_tpu=use_remote_tpu)
+      FLAGS.train_batch_size)


 def main(_):
@@ -200,7 +196,6 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...
@@ -245,7 +245,6 @@ def train_squad(strategy,
   loss_fn = get_loss_fn(
       loss_factor=1.0 /
       strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
-  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)

   model_training_utils.run_customized_training_loop(
       strategy=strategy,
@@ -257,7 +256,6 @@ def train_squad(strategy,
       epochs=epochs,
       train_input_fn=train_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu,
       run_eagerly=run_eagerly,
       custom_callbacks=custom_callbacks)
@@ -366,7 +364,6 @@ def main(_):
   elif FLAGS.strategy_type == 'multi_worker_mirror':
     strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
...
@@ -126,23 +126,13 @@ def get_metric_fn():
   return train_acc_metric


-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
-
-
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
   else:
     raise ValueError("The distribution strategy type is not supported: %s" %
                      FLAGS.strategy_type)
@@ -180,23 +170,22 @@ def main(unused_argv):
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["n_class"] = FLAGS.n_class

-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
-    training_utils.train(
-        [same arguments as in the call below, one indent level deeper under
-         the removed tf.device scope]
+  training_utils.train(
+      strategy=strategy,
+      model_fn=model_fn,
+      input_meta_data=input_meta_data,
+      eval_fn=eval_fn,
+      metric_fn=get_metric_fn,
+      train_input_fn=train_input_fn,
+      test_input_fn=test_input_fn,
+      init_checkpoint=FLAGS.init_checkpoint,
+      total_training_steps=total_training_steps,
+      steps_per_epoch=steps_per_epoch,
+      steps_per_loop=steps_per_loop,
+      optimizer=optimizer,
+      learning_rate_fn=learning_rate_fn,
+      model_dir=FLAGS.model_dir,
+      save_steps=1000)


 if __name__ == "__main__":
...
@@ -52,24 +52,14 @@ def get_pretrainxlnet_model(model_config, run_config):
   return model


-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
-
-
 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   num_hosts = 1
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
     topology = FLAGS.tpu_topology.split("x")
     total_num_core = 2 * int(topology[0]) * int(topology[1])
     num_hosts = total_num_core // FLAGS.num_core_per_host
@@ -111,23 +101,22 @@ def main(unused_argv):
   model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                run_config)

-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
-    training_utils.train(
-        [same arguments as in the call below, one indent level deeper under
-         the removed tf.device scope]
+  training_utils.train(
+      strategy=strategy,
+      model_fn=model_fn,
+      input_meta_data=input_meta_data,
+      eval_fn=None,
+      metric_fn=None,
+      train_input_fn=train_input_fn,
+      test_input_fn=None,
+      init_checkpoint=FLAGS.init_checkpoint,
+      total_training_steps=total_training_steps,
+      steps_per_epoch=steps_per_epoch,
+      steps_per_loop=steps_per_loop,
+      optimizer=optimizer,
+      learning_rate_fn=learning_rate_fn,
+      model_dir=FLAGS.model_dir,
+      save_steps=FLAGS.save_steps)


 if __name__ == "__main__":
...
@@ -91,13 +91,6 @@ class InputFeatures(object):
     self.is_impossible = is_impossible


-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return "/job:worker" if use_remote_tpu else ""
-
-
 # pylint: disable=unused-argument
 def run_evaluation(strategy,
                    test_input_fn,
@@ -224,14 +217,11 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):

 def main(unused_argv):
   del unused_argv
-  use_remote_tpu = False
   if FLAGS.strategy_type == "mirror":
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == "tpu":
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    use_remote_tpu = True
   else:
     raise ValueError("The distribution strategy type is not supported: %s" %
                      FLAGS.strategy_type)
@@ -285,22 +275,21 @@ def main(unused_argv):
   eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                               eval_steps, input_meta_data)

-  with tf.device(get_primary_cpu_task(use_remote_tpu)):
-    training_utils.train(
-        [same arguments as in the call below, one indent level deeper under
-         the removed tf.device scope]
+  training_utils.train(
+      strategy=strategy,
+      model_fn=model_fn,
+      input_meta_data=input_meta_data,
+      eval_fn=eval_fn,
+      metric_fn=None,
+      train_input_fn=train_input_fn,
+      test_input_fn=test_input_fn,
+      init_checkpoint=FLAGS.init_checkpoint,
+      total_training_steps=total_training_steps,
+      steps_per_epoch=steps_per_epoch,
+      steps_per_loop=steps_per_loop,
+      optimizer=optimizer,
+      learning_rate_fn=learning_rate_fn,
+      model_dir=FLAGS.model_dir)


 if __name__ == "__main__":
...
@@ -43,7 +43,6 @@ from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
 from official.utils.flags import core as flags_core
-from official.utils.misc import tpu_lib

 FLAGS = flags.FLAGS
@@ -254,84 +253,80 @@ def run_ncf(_):
         "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
     callbacks.append(early_stopping_callback)

-  use_remote_tpu = params["use_tpu"] and FLAGS.tpu
-  primary_cpu_task = tpu_lib.get_primary_cpu_task(use_remote_tpu)
-
-  with tf.device(primary_cpu_task):
-    [body below, unchanged except that it was nested one indentation level
-     deeper under the removed tf.device scope]
+  (train_input_dataset, eval_input_dataset,
+   num_train_steps, num_eval_steps) = \
+      (ncf_input_pipeline.create_ncf_input_data(
+          params, producer, input_meta_data, strategy))
+  steps_per_epoch = None if generate_input_online else num_train_steps
+
+  with distribution_utils.get_strategy_scope(strategy):
+    keras_model = _get_keras_model(params)
+    optimizer = tf.keras.optimizers.Adam(
+        learning_rate=params["learning_rate"],
+        beta_1=params["beta1"],
+        beta_2=params["beta2"],
+        epsilon=params["epsilon"])
+    if FLAGS.dtype == "fp16":
+      optimizer = \
+        tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
+            optimizer,
+            loss_scale=flags_core.get_loss_scale(FLAGS,
+                                                 default_for_fp16="dynamic"))
+
+    if params["keras_use_ctl"]:
+      train_loss, eval_results = run_ncf_custom_training(
+          params,
+          strategy,
+          keras_model,
+          optimizer,
+          callbacks,
+          train_input_dataset,
+          eval_input_dataset,
+          num_train_steps,
+          num_eval_steps,
+          generate_input_online=generate_input_online)
+    else:
+      # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
+      # a valid arg for this model. Also remove as a valid flag.
+      if FLAGS.force_v2_in_keras_compile is not None:
+        keras_model.compile(
+            optimizer=optimizer,
+            run_eagerly=FLAGS.run_eagerly,
+            experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
+      else:
+        keras_model.compile(
+            optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)
+
+      history = keras_model.fit(
+          train_input_dataset,
+          epochs=FLAGS.train_epochs,
+          steps_per_epoch=steps_per_epoch,
+          callbacks=callbacks,
+          validation_data=eval_input_dataset,
+          validation_steps=num_eval_steps,
+          verbose=2)
+
+      logging.info("Training done. Start evaluating")
+
+      eval_loss_and_metrics = keras_model.evaluate(
+          eval_input_dataset, steps=num_eval_steps, verbose=2)
+
+      logging.info("Keras evaluation is done.")
+
+      # Keras evaluate() API returns scalar loss and metric values from
+      # evaluation as a list. Here, the returned list would contain
+      # [evaluation loss, hr sum, hr count].
+      eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]
+
+      # Format evaluation result into [eval loss, eval hit accuracy].
+      eval_results = [eval_loss_and_metrics[0], eval_hit_rate]
+
+      if history and history.history:
+        train_history = history.history
+        train_loss = train_history["loss"][-1]
+
+  stats = build_stats(train_loss, eval_results, time_callback)
+  return stats


 def run_ncf_custom_training(params,
...
@@ -128,7 +128,6 @@ def get_distribution_strategy(distribution_strategy="default",
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
-    # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
     return tf.distribute.experimental.TPUStrategy(cluster_resolver)
...
@@ -21,18 +21,14 @@ def tpu_initialize(tpu_address):
   """Initializes TPU for TF 2.0 training.

   Args:
-    tpu_address: string, bns address of TPU workers.
+    tpu_address: string, bns address of master TPU worker.

   Returns:
     A TPUClusterResolver.
   """
   cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
       tpu=tpu_address)
-  tf.config.experimental_connect_to_host(cluster_resolver.master())
+  if tpu_address not in ('', 'local'):
+    tf.config.experimental_connect_to_cluster(cluster_resolver)
   tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
   return cluster_resolver
-
-
-def get_primary_cpu_task(use_remote_tpu=False):
-  """Returns remote TPU worker address. No-op for GPU/CPU training."""
-  return "/job:worker" if use_remote_tpu else ""