Unverified commit 51289259, authored by saberkun and committed by GitHub

Merged commit includes the following changes: (#6963)

251681245  by hongkuny<hongkuny@google.com>:

    Update bert to use the new tf.distribute APIs

--
251575972  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Remove `steps_per_run` when instantiating TPUStrategy.

--

PiperOrigin-RevId: 251681245
parent d01ac976
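
At a high level, this change replaces the `strategy.make_dataset_iterator` / `strategy.experimental_run` pair with `strategy.experimental_distribute_dataset` and `strategy.experimental_run_v2`: the host loop now owns a plain Python iterator over the distributed dataset and passes each element to the replicated step explicitly. Below is a minimal, self-contained sketch of the new pattern. It is illustrative only: the toy model, dataset, and names are not from this commit, MirroredStrategy stands in for the TPUStrategy used by the BERT scripts, and it assumes a TF 2.x build where these experimental APIs exist (in later releases `experimental_run_v2` was renamed to `Strategy.run`).

```python
# Illustrative sketch of the new tf.distribute pattern adopted by this commit.
# The toy model, dataset, and names below are not part of the commit.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def input_fn():
  x = tf.random.uniform([64, 8])
  y = tf.random.uniform([64, 1])
  return tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
  optimizer = tf.keras.optimizers.SGD(0.1)

def replicated_step(inputs):
  features, labels = inputs
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(features) - labels))
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss

# Old pattern (removed by this commit):
#   iterator = strategy.make_dataset_iterator(input_fn())
#   per_replica_losses = strategy.experimental_run(replicated_step, iterator)
# New pattern: distribute the dataset, drive a plain Python iterator from the
# host loop, and pass each element explicitly to experimental_run_v2.
iterator = iter(strategy.experimental_distribute_dataset(input_fn()))
for _ in range(4):
  per_replica_losses = strategy.experimental_run_v2(
      replicated_step, args=(next(iterator),))
  # Aggregate per-replica losses into a single scalar for reporting.
  loss = strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
print('mean loss:', float(loss))
```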
@@ -125,7 +125,8 @@ def run_customized_training_loop(
   # To reduce unnecessary send/receive input pipeline operation, we place input
   # pipeline ops in worker task.
   with tf.device(get_primary_cpu_task(use_remote_tpu)):
-    train_iterator = strategy.make_dataset_iterator(train_input_fn())
+    train_iterator = iter(
+        strategy.experimental_distribute_dataset(train_input_fn()))

     with strategy.scope():
       total_training_steps = steps_per_epoch * epochs
@@ -171,9 +172,8 @@ def run_customized_training_loop(
         optimizer.apply_gradients(zip(grads, tvars))
         return loss

-      per_replica_losses = strategy.experimental_run(_replicated_step,
-                                                     iterator)
+      per_replica_losses = strategy.experimental_run_v2(
+          _replicated_step, args=(next(iterator),))

       # For reporting, we returns the mean of losses.
       loss = strategy.reduce(
           tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
@@ -190,10 +190,11 @@ def run_customized_training_loop(
         model_outputs = model(inputs, training=False)
         metric.update_state(labels, model_outputs)

-      strategy.experimental_run(_test_step_fn, iterator)
+      strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))

-    def _run_evaluation(current_training_step, test_iterator):
+    def _run_evaluation(current_training_step, test_dataset):
       """Runs validation steps and aggregate metrics."""
+      test_iterator = iter(test_dataset)
       for _ in range(eval_steps):
         test_step(test_iterator)
@@ -259,8 +260,9 @@ def run_customized_training_loop(
       if eval_input_fn:
         logging.info('Running evaluation after step: %s.', current_step)
-        _run_evaluation(current_step,
-                        strategy.make_dataset_iterator(eval_input_fn()))
+        _run_evaluation(
+            current_step,
+            strategy.experimental_distribute_dataset(eval_input_fn()))
         # Re-initialize evaluation metric, except the last step.
         if metric and current_step < total_training_steps:
@@ -273,7 +275,8 @@ def run_customized_training_loop(
     if eval_input_fn:
       logging.info('Running final evaluation after training is complete.')
       eval_metric_result = _run_evaluation(
-          current_step, strategy.make_dataset_iterator(eval_input_fn()))
+          current_step,
+          strategy.experimental_distribute_dataset(eval_input_fn()))

     training_summary = {
         'total_training_steps': total_training_steps,
......
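
The evaluation path in this training-loop file changes in the same spirit: `_run_evaluation` now takes the distributed eval dataset rather than a pre-built iterator and creates a fresh iterator on each call, so every evaluation pass starts from the beginning of the eval data. A rough sketch of that shape, with hypothetical names standing in for the loop's own helpers:

```python
# Rough sketch of the reworked evaluation helper. The names (test_step_fn,
# eval_steps) are stand-ins for the training loop's own helpers, not code
# copied from this commit.
def run_evaluation(strategy, test_dataset, test_step_fn, eval_steps):
  test_iterator = iter(test_dataset)  # fresh iterator per evaluation run
  for _ in range(eval_steps):
    strategy.experimental_run_v2(test_step_fn, args=(next(test_iterator),))

# The caller re-distributes the eval dataset each time evaluation is
# triggered, e.g.:
#   run_evaluation(strategy,
#                  strategy.experimental_distribute_dataset(eval_input_fn()),
#                  test_step_fn, eval_steps)
```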
@@ -74,8 +74,6 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 8, 'Batch size for evaluation.')
 flags.DEFINE_integer('num_train_epochs', 3,
                      'Total number of training epochs to perform.')
-flags.DEFINE_integer('steps_per_run', 200,
-                     'Number of steps running on TPU devices.')
 flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')

 FLAGS = flags.FLAGS
@@ -240,8 +238,7 @@ def main(_):
   elif FLAGS.strategy_type == 'tpu':
     # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
-    strategy = tf.distribute.experimental.TPUStrategy(
-        cluster_resolver, steps_per_run=FLAGS.steps_per_run)
+    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
......
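
Because the new loop issues one `experimental_run_v2` call per step from the host, the `steps_per_run` argument (which let TPUStrategy batch several steps per host-to-device round trip) is dropped from the TPUStrategy constructor in the classifier, pretraining, and SQuAD scripts, and the corresponding flag is removed. A rough standalone equivalent of the updated TPU setup, assuming `tpu_lib.tpu_initialize` roughly wraps the standard resolver and TPU-system initialization (`tpu_address` is an illustrative parameter name):

```python
# Rough standalone equivalent of the updated TPU setup. Assumes
# tpu_lib.tpu_initialize roughly wraps the standard resolver plus TPU-system
# initialization; tpu_address is an illustrative parameter name.
import tensorflow as tf

def make_tpu_strategy(tpu_address):
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  # steps_per_run is no longer passed; each experimental_run_v2 call is a
  # single training step driven from the host loop.
  return tf.distribute.experimental.TPUStrategy(resolver)
```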
@@ -56,9 +56,6 @@ flags.DEFINE_integer(
 flags.DEFINE_integer('max_predictions_per_seq', 20,
                      'Maximum predictions per sequence_output.')
 flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
-flags.DEFINE_integer(
-    'steps_per_run', 1000,
-    'Number of steps to run in TPU worker before returning to host.')
 flags.DEFINE_integer('num_train_epochs', 3,
                      'Total number of training epochs to perform.')
 flags.DEFINE_integer('num_steps_per_epoch', 1000,
@@ -167,8 +164,7 @@ def main(_):
   elif FLAGS.strategy_type == 'tpu':
     # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
-    strategy = tf.distribute.experimental.TPUStrategy(
-        cluster_resolver, steps_per_run=FLAGS.steps_per_run)
+    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
......
@@ -64,8 +64,6 @@ flags.DEFINE_enum(
 flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
 flags.DEFINE_integer('num_train_epochs', 3,
                      'Total number of training epochs to perform.')
-flags.DEFINE_integer('steps_per_run', 200,
-                     'Number of steps running on TPU devices.')
 flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')

 # Predict processing related.
@@ -152,7 +150,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
       input_meta_data['max_seq_length'],
       FLAGS.predict_batch_size,
       is_training=False)
-  predict_iterator = strategy.make_dataset_iterator(predict_dataset)
+  predict_iterator = iter(
+      strategy.experimental_distribute_dataset(predict_dataset))

   with strategy.scope():
     squad_model, _ = bert_models.squad_model(
@@ -167,7 +166,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   def predict_step(iterator):
     """Predicts on distributed devices."""

-    def replicated_step(inputs):
+    def _replicated_step(inputs):
       """Replicated prediction calculation."""
       x, _ = inputs
       unique_ids, start_logits, end_logits = squad_model(x, training=False)
@@ -176,7 +175,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
           start_logits=start_logits,
           end_logits=end_logits)

-    outputs = strategy.experimental_run(replicated_step, iterator)
+    outputs = strategy.experimental_run_v2(
+        _replicated_step, args=(next(iterator),))
     return tf.nest.map_structure(strategy.unwrap, outputs)

   all_results = []
@@ -316,8 +316,7 @@ def main(_):
   elif FLAGS.strategy_type == 'tpu':
     # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
-    strategy = tf.distribute.experimental.TPUStrategy(
-        cluster_resolver, steps_per_run=FLAGS.steps_per_run)
+    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
......
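
The SQuAD prediction path follows the same pattern: distribute the prediction dataset, drive it with `next(iterator)`, run the replicated step with `experimental_run_v2`, and pull the per-replica outputs back to the host. The script uses `tf.nest.map_structure(strategy.unwrap, ...)` because its step returns a structured result per replica; `strategy.unwrap` was later renamed `experimental_local_results`. A loose sketch of that shape, with a placeholder model and a flat output:

```python
# Loose sketch of the distributed prediction loop shape used after this
# change. The model call and the flat output are placeholders; the real
# script returns a structured result per replica.
import tensorflow as tf

def predict(strategy, model, predict_dataset, num_batches):
  iterator = iter(strategy.experimental_distribute_dataset(predict_dataset))

  def replicated_step(inputs):
    features, _ = inputs
    return model(features, training=False)

  all_results = []
  for _ in range(num_batches):
    per_replica_outputs = strategy.experimental_run_v2(
        replicated_step, args=(next(iterator),))
    # strategy.unwrap returns the value from each replica as a tuple of
    # host-visible tensors.
    all_results.extend(strategy.unwrap(per_replica_outputs))
  return all_results
```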