Commit 497989e0 authored by Bruce Fontaine, committed by A. Unique TensorFlower

Use experimental_connect_to_cluster API in TPU lib to support training on a slice of a TPU pod.

PiperOrigin-RevId: 270926016
parent a52564cb
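For context, a minimal sketch of the connection flow this change adopts (illustrative only, not part of the diff; the TPU name is a placeholder). `tf.config.experimental_connect_to_cluster` registers every worker host of the resolved TPU cluster, whereas the previous `experimental_connect_to_host(resolver.master())` call only connected to the master worker, which is why training on a multi-host pod slice was not supported:

import tensorflow as tf

# Placeholder TPU name/address; the real scripts pass FLAGS.tpu.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu-pod-slice')

# New path: connect to every host in the cluster, then initialize the TPU system.
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

# Old path (removed in this commit): only the master worker was connected.
# tf.config.experimental_connect_to_host(resolver.master())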
......@@ -130,8 +130,7 @@ def run_customized_training_loop(
after every epoch.
init_checkpoint: Optional checkpoint to load into the `sub_model` returned by
`model_fn`.
use_remote_tpu: If true, input pipeline ops are placed on the TPU worker host
as an optimization.
use_remote_tpu: Ignored, will be removed in the future.
custom_callbacks: A list of Keras Callback objects to run during
training. More specifically, the `on_batch_begin()` and `on_batch_end()`
methods are invoked during training.
......@@ -146,6 +145,8 @@ def run_customized_training_loop(
attribute or when required parameters are set to None. (2) eval args are
not specified correctly. (3) metric_fn must be a callable if specified.
"""
# TODO(bfontain): Remove use_remote_tpu once there are no models using it.
del use_remote_tpu
if _sentinel is not None:
raise ValueError('only call `run_customized_training_loop()` '
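With `use_remote_tpu` now ignored (see the TODO above), the call sites updated in this commit simply drop the argument. A hedged sketch of the resulting call shape, with keyword names taken from the call sites and function body elsewhere in this diff; the exact signature is an assumption, and `model_fn`, `loss_fn`, the input functions, and the remaining values are placeholders:

trained_model = model_training_utils.run_customized_training_loop(
    strategy=strategy,                  # e.g. TPUStrategy built from tpu_lib.tpu_initialize
    model_fn=model_fn,                  # returns (model, sub_model); model must set .optimizer
    loss_fn=loss_fn,
    model_dir=model_dir,
    train_input_fn=train_input_fn,
    steps_per_epoch=steps_per_epoch,
    steps_per_loop=steps_per_loop,
    epochs=epochs,
    eval_input_fn=eval_input_fn,
    eval_steps=eval_steps,
    metric_fn=metric_fn,
    init_checkpoint=init_checkpoint,
    custom_callbacks=custom_callbacks,
    run_eagerly=run_eagerly)            # note: no use_remote_tpu argument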
......@@ -188,233 +189,232 @@ def run_customized_training_loop(
# To reduce unnecessary send/receive input pipeline operations, we place input
# pipeline ops in the worker task.
with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
train_iterator = _get_input_iterator(train_input_fn, strategy)
with distribution_utils.get_strategy_scope(strategy):
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model, sub_model = model_fn()
if not hasattr(model, 'optimizer'):
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
optimizer = model.optimizer
use_float16 = isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
if init_checkpoint:
logging.info(
'Checkpoint file %s found and restoring from '
'initial checkpoint for core model.', init_checkpoint)
checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_consumed()
logging.info('Loading from checkpoint file completed')
train_loss_metric = tf.keras.metrics.Mean(
'training_loss', dtype=tf.float32)
eval_metrics = [metric_fn()] if metric_fn else []
# If evaluation is required, make a copy of the metric, as it will be
# used by both training and evaluation.
train_metrics = [
metric.__class__.from_config(metric.get_config())
for metric in eval_metrics
]
# Create summary writers
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'summaries/eval'))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
# Only writes summaries when the stats are collected over sufficiently
# many steps.
train_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'summaries/train'))
else:
train_summary_writer = None
# Collects training variables.
training_vars = model.trainable_variables
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
model_outputs = model(inputs, training=True)
loss = loss_fn(labels, model_outputs)
if use_float16:
scaled_loss = optimizer.get_scaled_loss(loss)
train_iterator = _get_input_iterator(train_input_fn, strategy)
with distribution_utils.get_strategy_scope(strategy):
# To correctly place the model weights on accelerators,
# model and optimizer should be created in scope.
model, sub_model = model_fn()
if not hasattr(model, 'optimizer'):
raise ValueError('User should set optimizer attribute to model '
'inside `model_fn`.')
optimizer = model.optimizer
use_float16 = isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
if init_checkpoint:
logging.info(
'Checkpoint file %s found and restoring from '
'initial checkpoint for core model.', init_checkpoint)
checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_consumed()
logging.info('Loading from checkpoint file completed')
train_loss_metric = tf.keras.metrics.Mean(
'training_loss', dtype=tf.float32)
eval_metrics = [metric_fn()] if metric_fn else []
# If evaluation is required, make a copy of the metric, as it will be
# used by both training and evaluation.
train_metrics = [
metric.__class__.from_config(metric.get_config())
for metric in eval_metrics
]
# Create summary writers
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'summaries/eval'))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
# Only writes summaries when the stats are collected over sufficiently
# many steps.
train_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'summaries/train'))
else:
train_summary_writer = None
# Collects training variables.
training_vars = model.trainable_variables
def _replicated_step(inputs):
"""Replicated training step."""
inputs, labels = inputs
with tf.GradientTape() as tape:
model_outputs = model(inputs, training=True)
loss = loss_fn(labels, model_outputs)
if use_float16:
scaled_grads = tape.gradient(scaled_loss, training_vars)
grads = optimizer.get_unscaled_gradients(scaled_grads)
else:
grads = tape.gradient(loss, training_vars)
optimizer.apply_gradients(zip(grads, training_vars))
# For reporting, the metric takes the mean of losses.
train_loss_metric.update_state(loss)
for metric in train_metrics:
metric.update_state(labels, model_outputs)
scaled_loss = optimizer.get_scaled_loss(loss)
@tf.function
def train_steps(iterator, steps):
"""Performs distributed training steps in a loop.
Args:
iterator: the distributed iterator of training datasets.
steps: a tf.int32 integer tensor specifying the number of steps to run
inside the host training loop.
if use_float16:
scaled_grads = tape.gradient(scaled_loss, training_vars)
grads = optimizer.get_unscaled_gradients(scaled_grads)
else:
grads = tape.gradient(loss, training_vars)
optimizer.apply_gradients(zip(grads, training_vars))
# For reporting, the metric takes the mean of losses.
train_loss_metric.update_state(loss)
for metric in train_metrics:
metric.update_state(labels, model_outputs)
@tf.function
def train_steps(iterator, steps):
"""Performs distributed training steps in a loop.
Args:
iterator: the distributed iterator of training datasets.
steps: a tf.int32 integer tensor specifying the number of steps to run
inside the host training loop.
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
if not isinstance(steps, tf.Tensor):
raise ValueError('steps should be a Tensor. A Python object may cause '
'retracing.')
for _ in tf.range(steps):
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
if not isinstance(steps, tf.Tensor):
raise ValueError('steps should be a Tensor. A Python object may cause '
'retracing.')
def train_single_step(iterator):
"""Performs a distributed training step.
for _ in tf.range(steps):
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
Args:
iterator: the distributed iterator of training datasets.
def train_single_step(iterator):
"""Performs a distributed training step.
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
Args:
iterator: the distributed iterator of training datasets.
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
for metric in eval_metrics:
metric.update_state(labels, model_outputs)
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
if not run_eagerly:
train_single_step = tf.function(train_single_step)
test_step = tf.function(test_step)
def _run_evaluation(current_training_step, test_iterator):
"""Runs validation steps and aggregate metrics."""
for _ in range(eval_steps):
test_step(test_iterator)
with eval_summary_writer.as_default():
for metric in eval_metrics + model.metrics:
metric_value = _float_metric_value(metric)
logging.info('Step: [%d] Validation %s = %f', current_training_step,
metric.name, metric_value)
tf.summary.scalar(
metric.name, metric_value, step=current_training_step)
eval_summary_writer.flush()
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_end(batch)
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
logging.info('Loading from checkpoint file completed')
current_step = optimizer.iterations.numpy()
checkpoint_name = 'ctl_step_{step}.ckpt'
while current_step < total_training_steps:
# Training loss/metrics take the average over steps inside the micro
# training loop. We reset their values before each round.
train_loss_metric.reset_states()
for metric in train_metrics + model.metrics:
metric.reset_states()
_run_callbacks_on_batch_begin(current_step)
# Runs several steps in the host while loop.
steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
if steps == 1:
# TODO(zongweiz): merge with train_steps once tf.while_loop
# GPU performance bugs are fixed.
train_single_step(train_iterator)
else:
# Converts steps to a Tensor to avoid tf.function retracing.
train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32))
_run_callbacks_on_batch_end(current_step)
current_step += steps
train_loss = _float_metric_value(train_loss_metric)
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
current_step, total_training_steps, train_loss)
if train_summary_writer:
with train_summary_writer.as_default():
tf.summary.scalar(
train_loss_metric.name, train_loss, step=current_step)
for metric in train_metrics + model.metrics:
metric_value = _float_metric_value(metric)
training_status += ' %s = %f' % (metric.name, metric_value)
tf.summary.scalar(metric.name, metric_value, step=current_step)
train_summary_writer.flush()
logging.info(training_status)
# Saves model checkpoints and runs validation steps at every epoch end.
if current_step % steps_per_epoch == 0:
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_training_steps:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
for metric in eval_metrics + model.metrics:
metric.reset_states()
inputs, labels = inputs
model_outputs = model(inputs, training=False)
for metric in eval_metrics:
metric.update_state(labels, model_outputs)
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
if eval_input_fn:
logging.info('Running final evaluation after training is complete.')
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
if not run_eagerly:
train_single_step = tf.function(train_single_step)
test_step = tf.function(test_step)
training_summary = {
'total_training_steps': total_training_steps,
'train_loss': _float_metric_value(train_loss_metric),
}
if eval_metrics:
# TODO(hongkuny): Clean up summary reporting in text.
training_summary['last_train_metrics'] = _float_metric_value(
train_metrics[0])
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
def _run_evaluation(current_training_step, test_iterator):
"""Runs validation steps and aggregate metrics."""
for _ in range(eval_steps):
test_step(test_iterator)
write_txt_summary(training_summary, model_dir)
with eval_summary_writer.as_default():
for metric in eval_metrics + model.metrics:
metric_value = _float_metric_value(metric)
logging.info('Step: [%d] Validation %s = %f', current_training_step,
metric.name, metric_value)
tf.summary.scalar(
metric.name, metric_value, step=current_training_step)
eval_summary_writer.flush()
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_begin(batch)
def _run_callbacks_on_batch_end(batch):
"""Runs custom callbacks at the end of every step."""
if not custom_callbacks:
return
for callback in custom_callbacks:
callback.on_batch_end(batch)
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
logging.info('Loading from checkpoint file completed')
current_step = optimizer.iterations.numpy()
checkpoint_name = 'ctl_step_{step}.ckpt'
while current_step < total_training_steps:
# Training loss/metrics take the average over steps inside the micro
# training loop. We reset their values before each round.
train_loss_metric.reset_states()
for metric in train_metrics + model.metrics:
metric.reset_states()
_run_callbacks_on_batch_begin(current_step)
# Runs several steps in the host while loop.
steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
if steps == 1:
# TODO(zongweiz): merge with train_steps once tf.while_loop
# GPU performance bugs are fixed.
train_single_step(train_iterator)
else:
# Converts steps to a Tensor to avoid tf.function retracing.
train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32))
_run_callbacks_on_batch_end(current_step)
current_step += steps
train_loss = _float_metric_value(train_loss_metric)
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
current_step, total_training_steps, train_loss)
if train_summary_writer:
with train_summary_writer.as_default():
tf.summary.scalar(
train_loss_metric.name, train_loss, step=current_step)
for metric in train_metrics + model.metrics:
metric_value = _float_metric_value(metric)
training_status += ' %s = %f' % (metric.name, metric_value)
tf.summary.scalar(metric.name, metric_value, step=current_step)
train_summary_writer.flush()
logging.info(training_status)
# Saves model checkpoints and runs validation steps at every epoch end.
if current_step % steps_per_epoch == 0:
# To avoid repeated model saving, we do not save after the last
# step of training.
if current_step < total_training_steps:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
for metric in eval_metrics + model.metrics:
metric.reset_states()
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running final evaluation after training is complete.')
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
training_summary = {
'total_training_steps': total_training_steps,
'train_loss': _float_metric_value(train_loss_metric),
}
if eval_metrics:
# TODO(hongkuny): Clean up summary reporting in text.
training_summary['last_train_metrics'] = _float_metric_value(
train_metrics[0])
training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
write_txt_summary(training_summary, model_dir)
return model
return model
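On the retracing guard in `train_steps` above: a Python integer passed to a `tf.function` is baked into the trace signature, so each distinct value triggers a new trace, while a tensor argument is traced once and reused. A standalone illustration (not part of the commit):

import tensorflow as tf

@tf.function
def count_steps(steps):
  # Trivial tf.range loop, mirroring the structure of the host training loop.
  total = tf.constant(0)
  for _ in tf.range(steps):
    total += 1
  return total

count_steps(10)  # traces for the Python value 10
count_steps(20)  # retraces for the new Python value 20
count_steps(tf.convert_to_tensor(30, dtype=tf.int32))  # one trace serves all int32 scalars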
......@@ -152,7 +152,6 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
eval_steps=10,
init_checkpoint=None,
metric_fn=metric_fn,
use_remote_tpu=False,
custom_callbacks=None,
run_eagerly=run_eagerly)
......
......@@ -90,7 +90,6 @@ def run_customized_training(strategy,
warmup_steps,
initial_lr,
init_checkpoint,
use_remote_tpu=False,
custom_callbacks=None,
run_eagerly=False):
"""Run BERT classifier training using low-level API."""
......@@ -151,7 +150,6 @@ def run_customized_training(strategy,
eval_steps=eval_steps,
init_checkpoint=init_checkpoint,
metric_fn=metric_fn,
use_remote_tpu=use_remote_tpu,
custom_callbacks=custom_callbacks,
run_eagerly=run_eagerly)
......@@ -201,7 +199,6 @@ def run_bert(strategy, input_meta_data):
# Runs customized training loop.
logging.info('Training using customized training loop TF 2.0 with distributed '
'strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
trained_model = run_customized_training(
strategy,
bert_config,
......@@ -214,13 +211,11 @@ def run_bert(strategy, input_meta_data):
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu,
run_eagerly=FLAGS.run_eagerly)
if FLAGS.model_export_path:
with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
return trained_model
......@@ -238,7 +233,6 @@ def main(_):
if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'tpu':
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
......
......@@ -114,8 +114,7 @@ def run_customized_training(strategy,
initial_lr,
warmup_steps,
input_files,
train_batch_size,
use_remote_tpu=False):
train_batch_size):
"""Run BERT pretrain model training using low-level API."""
train_input_fn = functools.partial(get_pretrain_input_data, input_files,
......@@ -148,8 +147,7 @@ def run_customized_training(strategy,
train_input_fn=train_input_fn,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
epochs=epochs,
use_remote_tpu=use_remote_tpu)
epochs=epochs)
# Creates the BERT core model outside distribution strategy scope.
_, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
......@@ -173,7 +171,6 @@ def run_bert_pretrain(strategy):
logging.info('Training using customized training loop TF 2.0 with distributed '
'strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
return run_customized_training(
strategy,
bert_config,
......@@ -186,8 +183,7 @@ def run_bert_pretrain(strategy):
FLAGS.learning_rate,
FLAGS.warmup_steps,
FLAGS.input_files,
FLAGS.train_batch_size,
use_remote_tpu=use_remote_tpu)
FLAGS.train_batch_size)
def main(_):
......@@ -200,7 +196,6 @@ def main(_):
if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == 'tpu':
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
......
......@@ -245,7 +245,6 @@ def train_squad(strategy,
loss_fn = get_loss_fn(
loss_factor=1.0 /
strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
model_training_utils.run_customized_training_loop(
strategy=strategy,
......@@ -257,7 +256,6 @@ def train_squad(strategy,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks)
......@@ -366,7 +364,6 @@ def main(_):
elif FLAGS.strategy_type == 'multi_worker_mirror':
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
elif FLAGS.strategy_type == 'tpu':
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
......
......@@ -126,23 +126,13 @@ def get_metric_fn():
return train_acc_metric
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return "/job:worker" if use_remote_tpu else ""
def main(unused_argv):
del unused_argv
use_remote_tpu = False
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
use_remote_tpu = True
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
......@@ -180,23 +170,22 @@ def main(unused_argv):
input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
input_meta_data["n_class"] = FLAGS.n_class
with tf.device(get_primary_cpu_task(use_remote_tpu)):
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=eval_fn,
metric_fn=get_metric_fn,
train_input_fn=train_input_fn,
test_input_fn=test_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir,
save_steps=1000)
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=eval_fn,
metric_fn=get_metric_fn,
train_input_fn=train_input_fn,
test_input_fn=test_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir,
save_steps=1000)
if __name__ == "__main__":
......
......@@ -52,24 +52,14 @@ def get_pretrainxlnet_model(model_config, run_config):
return model
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return "/job:worker" if use_remote_tpu else ""
def main(unused_argv):
del unused_argv
use_remote_tpu = False
num_hosts = 1
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
use_remote_tpu = True
topology = FLAGS.tpu_topology.split("x")
total_num_core = 2 * int(topology[0]) * int(topology[1])
num_hosts = total_num_core // FLAGS.num_core_per_host
......@@ -111,23 +101,22 @@ def main(unused_argv):
model_fn = functools.partial(get_pretrainxlnet_model, model_config,
run_config)
with tf.device(get_primary_cpu_task(use_remote_tpu)):
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=None,
metric_fn=None,
train_input_fn=train_input_fn,
test_input_fn=None,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir,
save_steps=FLAGS.save_steps)
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=None,
metric_fn=None,
train_input_fn=train_input_fn,
test_input_fn=None,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir,
save_steps=FLAGS.save_steps)
if __name__ == "__main__":
......
......@@ -91,13 +91,6 @@ class InputFeatures(object):
self.is_impossible = is_impossible
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return "/job:worker" if use_remote_tpu else ""
# pylint: disable=unused-argument
def run_evaluation(strategy,
test_input_fn,
......@@ -224,14 +217,11 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
def main(unused_argv):
del unused_argv
use_remote_tpu = False
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
use_remote_tpu = True
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
......@@ -285,22 +275,21 @@ def main(unused_argv):
eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
eval_steps, input_meta_data)
with tf.device(get_primary_cpu_task(use_remote_tpu)):
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=eval_fn,
metric_fn=None,
train_input_fn=train_input_fn,
test_input_fn=test_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir)
training_utils.train(
strategy=strategy,
model_fn=model_fn,
input_meta_data=input_meta_data,
eval_fn=eval_fn,
metric_fn=None,
train_input_fn=train_input_fn,
test_input_fn=test_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
total_training_steps=total_training_steps,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
optimizer=optimizer,
learning_rate_fn=learning_rate_fn,
model_dir=FLAGS.model_dir)
if __name__ == "__main__":
......
......@@ -43,7 +43,6 @@ from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.utils.flags import core as flags_core
from official.utils.misc import tpu_lib
FLAGS = flags.FLAGS
......@@ -254,84 +253,80 @@ def run_ncf(_):
"val_HR_METRIC", desired_value=FLAGS.hr_threshold)
callbacks.append(early_stopping_callback)
use_remote_tpu = params["use_tpu"] and FLAGS.tpu
primary_cpu_task = tpu_lib.get_primary_cpu_task(use_remote_tpu)
with tf.device(primary_cpu_task):
(train_input_dataset, eval_input_dataset,
num_train_steps, num_eval_steps) = \
(ncf_input_pipeline.create_ncf_input_data(
params, producer, input_meta_data, strategy))
steps_per_epoch = None if generate_input_online else num_train_steps
with distribution_utils.get_strategy_scope(strategy):
keras_model = _get_keras_model(params)
optimizer = tf.keras.optimizers.Adam(
learning_rate=params["learning_rate"],
beta_1=params["beta1"],
beta_2=params["beta2"],
epsilon=params["epsilon"])
if FLAGS.dtype == "fp16":
optimizer = \
tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer,
loss_scale=flags_core.get_loss_scale(FLAGS,
default_for_fp16="dynamic"))
if params["keras_use_ctl"]:
train_loss, eval_results = run_ncf_custom_training(
params,
strategy,
keras_model,
(train_input_dataset, eval_input_dataset,
num_train_steps, num_eval_steps) = \
(ncf_input_pipeline.create_ncf_input_data(
params, producer, input_meta_data, strategy))
steps_per_epoch = None if generate_input_online else num_train_steps
with distribution_utils.get_strategy_scope(strategy):
keras_model = _get_keras_model(params)
optimizer = tf.keras.optimizers.Adam(
learning_rate=params["learning_rate"],
beta_1=params["beta1"],
beta_2=params["beta2"],
epsilon=params["epsilon"])
if FLAGS.dtype == "fp16":
optimizer = \
tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer,
callbacks,
train_input_dataset,
eval_input_dataset,
num_train_steps,
num_eval_steps,
generate_input_online=generate_input_online)
loss_scale=flags_core.get_loss_scale(FLAGS,
default_for_fp16="dynamic"))
if params["keras_use_ctl"]:
train_loss, eval_results = run_ncf_custom_training(
params,
strategy,
keras_model,
optimizer,
callbacks,
train_input_dataset,
eval_input_dataset,
num_train_steps,
num_eval_steps,
generate_input_online=generate_input_online)
else:
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
# a valid arg for this model. Also remove as a valid flag.
if FLAGS.force_v2_in_keras_compile is not None:
keras_model.compile(
optimizer=optimizer,
run_eagerly=FLAGS.run_eagerly,
experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
else:
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
# a valid arg for this model. Also remove as a valid flag.
if FLAGS.force_v2_in_keras_compile is not None:
keras_model.compile(
optimizer=optimizer,
run_eagerly=FLAGS.run_eagerly,
experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
else:
keras_model.compile(
optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)
history = keras_model.fit(
train_input_dataset,
epochs=FLAGS.train_epochs,
steps_per_epoch=steps_per_epoch,
callbacks=callbacks,
validation_data=eval_input_dataset,
validation_steps=num_eval_steps,
verbose=2)
logging.info("Training done. Start evaluating")
eval_loss_and_metrics = keras_model.evaluate(
eval_input_dataset, steps=num_eval_steps, verbose=2)
logging.info("Keras evaluation is done.")
# Keras evaluate() API returns scalar loss and metric values from
# evaluation as a list. Here, the returned list would contain
# [evaluation loss, hr sum, hr count].
eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]
# Format evaluation result into [eval loss, eval hit accuracy].
eval_results = [eval_loss_and_metrics[0], eval_hit_rate]
if history and history.history:
train_history = history.history
train_loss = train_history["loss"][-1]
stats = build_stats(train_loss, eval_results, time_callback)
return stats
keras_model.compile(
optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)
history = keras_model.fit(
train_input_dataset,
epochs=FLAGS.train_epochs,
steps_per_epoch=steps_per_epoch,
callbacks=callbacks,
validation_data=eval_input_dataset,
validation_steps=num_eval_steps,
verbose=2)
logging.info("Training done. Start evaluating")
eval_loss_and_metrics = keras_model.evaluate(
eval_input_dataset, steps=num_eval_steps, verbose=2)
logging.info("Keras evaluation is done.")
# Keras evaluate() API returns scalar loss and metric values from
# evaluation as a list. Here, the returned list would contain
# [evaluation loss, hr sum, hr count].
eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]
# Format evaluation result into [eval loss, eval hit accuracy].
eval_results = [eval_loss_and_metrics[0], eval_hit_rate]
if history and history.history:
train_history = history.history
train_loss = train_history["loss"][-1]
stats = build_stats(train_loss, eval_results, time_callback)
return stats
def run_ncf_custom_training(params,
......
......@@ -128,7 +128,6 @@ def get_distribution_strategy(distribution_strategy="default",
if distribution_strategy == "tpu":
# When tpu_address is an empty string, we communicate with local TPUs.
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
return tf.distribute.experimental.TPUStrategy(cluster_resolver)
......
......@@ -21,18 +21,14 @@ def tpu_initialize(tpu_address):
"""Initializes TPU for TF 2.0 training.
Args:
tpu_address: string, bns address of TPU workers.
tpu_address: string, bns address of master TPU worker.
Returns:
A TPUClusterResolver.
"""
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_address)
tf.config.experimental_connect_to_host(cluster_resolver.master())
if tpu_address not in ('', 'local'):
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
return cluster_resolver
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns remote TPU worker address. No-op for GPU/CPU training."""
return "/job:worker" if use_remote_tpu else ""