Unverified commit 51289259 authored by saberkun, committed by GitHub

Merged commit includes the following changes: (#6963)

251681245  by hongkuny<hongkuny@google.com>:

    Update bert to use the new tf.distribute APIs

--
251575972  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Remove `steps_per_run` when instantiating TPUStrategy.

--

PiperOrigin-RevId: 251681245
parent d01ac976
@@ -125,7 +125,8 @@ def run_customized_training_loop(
   # To reduce unnecessary send/receive input pipeline operation, we place input
   # pipeline ops in worker task.
   with tf.device(get_primary_cpu_task(use_remote_tpu)):
-    train_iterator = strategy.make_dataset_iterator(train_input_fn())
+    train_iterator = iter(
+        strategy.experimental_distribute_dataset(train_input_fn()))
 
     with strategy.scope():
       total_training_steps = steps_per_epoch * epochs
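The hunk above swaps the strategy-owned iterator from make_dataset_iterator for a distributed dataset driven by the ordinary Python iterator protocol. A minimal sketch of the new idiom against the TF 2.0-era API (the MirroredStrategy and toy dataset are illustrative stand-ins, not part of this commit):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Any tf.data pipeline works; the global batch is split across replicas.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([64, 10]), tf.random.uniform([64, 1]))).batch(8)

# New idiom: distribute the dataset, then iterate it like any Python iterable
# instead of calling strategy.make_dataset_iterator(...).
train_iterator = iter(strategy.experimental_distribute_dataset(dataset))
features, labels = next(train_iterator)  # per-replica values on multi-device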
@@ -171,9 +172,8 @@ def run_customized_training_loop(
       optimizer.apply_gradients(zip(grads, tvars))
       return loss
 
-    per_replica_losses = strategy.experimental_run(_replicated_step,
-                                                   iterator)
-
+    per_replica_losses = strategy.experimental_run_v2(
+        _replicated_step, args=(next(iterator),))
     # For reporting, we returns the mean of losses.
     loss = strategy.reduce(
         tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
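experimental_run_v2 differs from experimental_run in that the caller advances the iterator and passes the per-replica batch through args, rather than handing the whole iterator to the strategy. A hedged sketch of the full train-step shape against the TF 2.0-era API (renamed Strategy.run in later releases); the linear model, SGD optimizer, and toy data are stand-ins for the BERT code in this commit:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(0.01)

@tf.function
def train_step(iterator):

  def _replicated_step(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(model(features) - labels))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  # The host advances the iterator; the strategy runs the step per replica.
  per_replica_losses = strategy.experimental_run_v2(
      _replicated_step, args=(next(iterator),))
  # As in the diff above: report the mean of the per-replica losses.
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([64, 10]), tf.random.uniform([64, 1]))).batch(8)
iterator = iter(strategy.experimental_distribute_dataset(dataset))
loss = train_step(iterator)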
@@ -190,10 +190,11 @@ def run_customized_training_loop(
       model_outputs = model(inputs, training=False)
       metric.update_state(labels, model_outputs)
 
-    strategy.experimental_run(_test_step_fn, iterator)
+    strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
 
-  def _run_evaluation(current_training_step, test_iterator):
+  def _run_evaluation(current_training_step, test_dataset):
     """Runs validation steps and aggregate metrics."""
+    test_iterator = iter(test_dataset)
     for _ in range(eval_steps):
       test_step(test_iterator)
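_run_evaluation now receives the distributed dataset rather than a live iterator and builds its own iterator, so each evaluation pass starts from the beginning of the eval data; with a shared iterator, successive mid-training evaluations would keep consuming, and eventually exhaust, a single pass. A stripped-down sketch of that shape (names here are illustrative):

def run_evaluation(test_dataset, test_step_fn, eval_steps):
  """Sketch: take the dataset, not an iterator, so every call starts fresh."""
  test_iterator = iter(test_dataset)  # new pass over the eval data each call
  for _ in range(eval_steps):
    test_step_fn(test_iterator)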
@@ -259,8 +260,9 @@ def run_customized_training_loop(
         if eval_input_fn:
           logging.info('Running evaluation after step: %s.', current_step)
-          _run_evaluation(current_step,
-                          strategy.make_dataset_iterator(eval_input_fn()))
+          _run_evaluation(
+              current_step,
+              strategy.experimental_distribute_dataset(eval_input_fn()))
           # Re-initialize evaluation metric, except the last step.
           if metric and current_step < total_training_steps:
@@ -273,7 +275,8 @@ def run_customized_training_loop(
     if eval_input_fn:
       logging.info('Running final evaluation after training is complete.')
       eval_metric_result = _run_evaluation(
-          current_step, strategy.make_dataset_iterator(eval_input_fn()))
+          current_step,
+          strategy.experimental_distribute_dataset(eval_input_fn()))
 
     training_summary = {
         'total_training_steps': total_training_steps,
@@ -74,8 +74,6 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 8, 'Batch size for evaluation.')
 flags.DEFINE_integer('num_train_epochs', 3,
                      'Total number of training epochs to perform.')
-flags.DEFINE_integer('steps_per_run', 200,
-                     'Number of steps running on TPU devices.')
 flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
 
 FLAGS = flags.FLAGS
@@ -240,8 +238,7 @@ def main(_):
   elif FLAGS.strategy_type == 'tpu':
     # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
-    strategy = tf.distribute.experimental.TPUStrategy(
-        cluster_resolver, steps_per_run=FLAGS.steps_per_run)
+    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
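steps_per_run was a TPUStrategy-specific knob that batched multiple training steps per host call; with the new training loop the host drives steps directly, so the constructor no longer takes it (hence the flag removals in this commit). A hedged sketch of TPU setup under the new signature; this approximates what the repo's tpu_lib.tpu_initialize helper does using public TF 2.0-era APIs, and the TPU address is a placeholder:

import tensorflow as tf

# Resolve and initialize the TPU system (roughly tpu_lib.tpu_initialize).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='grpc://10.0.0.1:8470')  # placeholder address
tf.config.experimental_connect_to_host(resolver.master())
tf.tpu.experimental.initialize_tpu_system(resolver)

# steps_per_run is gone from the constructor; only the resolver is passed.
strategy = tf.distribute.experimental.TPUStrategy(resolver)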
@@ -56,9 +56,6 @@ flags.DEFINE_integer(
 flags.DEFINE_integer('max_predictions_per_seq', 20,
                      'Maximum predictions per sequence_output.')
 flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
-flags.DEFINE_integer(
-    'steps_per_run', 1000,
-    'Number of steps to run in TPU worker before returning to host.')
 flags.DEFINE_integer('num_train_epochs', 3,
                      'Total number of training epochs to perform.')
 flags.DEFINE_integer('num_steps_per_epoch', 1000,
@@ -167,8 +164,7 @@ def main(_):
   elif FLAGS.strategy_type == 'tpu':
     # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
-    strategy = tf.distribute.experimental.TPUStrategy(
-        cluster_resolver, steps_per_run=FLAGS.steps_per_run)
+    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
@@ -64,8 +64,6 @@ flags.DEFINE_enum(
 flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
 flags.DEFINE_integer('num_train_epochs', 3,
                      'Total number of training epochs to perform.')
-flags.DEFINE_integer('steps_per_run', 200,
-                     'Number of steps running on TPU devices.')
 flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
 
 # Predict processing related.
@@ -152,7 +150,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
       input_meta_data['max_seq_length'],
       FLAGS.predict_batch_size,
       is_training=False)
-  predict_iterator = strategy.make_dataset_iterator(predict_dataset)
+  predict_iterator = iter(
+      strategy.experimental_distribute_dataset(predict_dataset))
 
   with strategy.scope():
     squad_model, _ = bert_models.squad_model(
@@ -167,7 +166,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   def predict_step(iterator):
     """Predicts on distributed devices."""
 
-    def replicated_step(inputs):
+    def _replicated_step(inputs):
       """Replicated prediction calculation."""
       x, _ = inputs
       unique_ids, start_logits, end_logits = squad_model(x, training=False)
@@ -176,7 +175,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
         start_logits=start_logits,
         end_logits=end_logits)
 
-    outputs = strategy.experimental_run(replicated_step, iterator)
+    outputs = strategy.experimental_run_v2(
+        _replicated_step, args=(next(iterator),))
     return tf.nest.map_structure(strategy.unwrap, outputs)
 
   all_results = []
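The prediction path follows the same migration, with one extra wrinkle: experimental_run_v2 returns per-replica values, and mapping strategy.unwrap over the output structure converts each one into per-device tensors the host can collect. A rough sketch of that pattern with a stand-in model (the real code returns unique_ids, start_logits, and end_logits from squad_model):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])  # stand-in model

@tf.function
def predict_step(iterator):

  def _replicated_step(inputs):
    return model(inputs, training=False)

  outputs = strategy.experimental_run_v2(
      _replicated_step, args=(next(iterator),))
  # unwrap flattens each per-replica value into a tuple of per-device tensors.
  return tf.nest.map_structure(strategy.unwrap, outputs)

dataset = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([32, 10])).batch(8)
iterator = iter(strategy.experimental_distribute_dataset(dataset))
per_device_logits = predict_step(iterator)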
@@ -316,8 +316,7 @@ def main(_):
   elif FLAGS.strategy_type == 'tpu':
     # Initialize TPU System.
     cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
-    strategy = tf.distribute.experimental.TPUStrategy(
-        cluster_resolver, steps_per_run=FLAGS.steps_per_run)
+    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)