Unverified Commit 240623ac authored by saberkun, committed by GitHub

Merged commit includes the following changes: (#7093)

254785517  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Use train_single_step for BERT GPU models to temporarily work around some performance bugs in GPU runs

--
254497647  by hongkuny<hongkuny@google.com>:

    Fix device placement for TPU export model.

--

PiperOrigin-RevId: 254785517
parent 1157c738
@@ -65,7 +65,9 @@ def _float_metric_value(metric):
 def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
   """Calculates steps to run on device."""
-  if steps_per_loop <= 1:
+  if steps_per_loop <= 0:
+    raise ValueError('steps_per_loop should be positive integer.')
+  if steps_per_loop == 1:
     return steps_per_loop
   remainder_in_epoch = current_step % steps_per_epoch
   if remainder_in_epoch != 0:
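The hunk above tightens the argument check, raising on non-positive `steps_per_loop` instead of returning it unchanged, but the diff cuts off before the remainder handling. A minimal, self-contained sketch of the whole helper, where the remainder branch is an assumption about how the function completes:

# Hedged sketch: only the validation at the top appears in the diff; the
# epoch-boundary handling below is assumed, not taken from the commit.
def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
  """Calculates steps to run on device."""
  if steps_per_loop <= 0:
    raise ValueError('steps_per_loop should be positive integer.')
  if steps_per_loop == 1:
    return steps_per_loop
  remainder_in_epoch = current_step % steps_per_epoch
  if remainder_in_epoch != 0:
    # Assumed behaviour: never run past the epoch boundary, so per-epoch
    # callbacks and evaluation still fire at the right step.
    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
  return steps_per_loop

# Example: 100 steps per epoch with a 32-step device loop yields 32, 32, 32, 4.
print(_steps_to_run(0, 100, 32))   # 32
print(_steps_to_run(96, 100, 32))  # 4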
@@ -198,9 +200,25 @@ def run_customized_training_loop(
         eval_metric.__class__.from_config(eval_metric.get_config())
         if eval_metric else None)
 
+    def _replicated_step(inputs):
+      """Replicated training step."""
+      inputs, labels = inputs
+      with tf.GradientTape() as tape:
+        model_outputs = model(inputs)
+        loss = loss_fn(labels, model_outputs)
+      tvars = model.trainable_variables
+      grads = tape.gradient(loss, tvars)
+      optimizer.apply_gradients(zip(grads, tvars))
+      # For reporting, the metric takes the mean of losses.
+      train_loss_metric.update_state(loss)
+      if train_metric:
+        train_metric.update_state(labels, model_outputs)
+
     @tf.function
-    def train_step(iterator, steps):
-      """Performs a distributed training step.
+    def train_steps(iterator, steps):
+      """Performs distributed training steps in a loop.
 
       Args:
         iterator: the distributed iterator of training datasets.
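This hunk hoists `_replicated_step` out of the old `train_step` so it can be shared by the multi-step and the new single-step `tf.function`s. As a rough illustration of that pattern, here is a self-contained toy; the model, dataset, optimizer, and loss below are placeholders rather than the BERT code, and `experimental_run_v2` is the TF 2.0-era name of what later became `Strategy.run`:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  optimizer = tf.keras.optimizers.SGD(0.1)
  train_loss_metric = tf.keras.metrics.Mean('training_loss')
loss_fn = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.SUM)

def _replicated_step(inputs):
  """Runs on each replica: forward pass, gradient step, loss metric update."""
  features, labels = inputs
  with tf.GradientTape() as tape:
    outputs = model(features)
    loss = loss_fn(labels, outputs)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  train_loss_metric.update_state(loss)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 4]), tf.random.normal([64, 1]))).batch(8)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def train_steps(iterator, steps):
  # AutoGraph turns this tf.range loop into a single tf.while_loop, so many
  # steps run per function call instead of one.
  for _ in tf.range(steps):
    strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))

train_steps(iterator, tf.convert_to_tensor(4, dtype=tf.int32))
print(float(train_loss_metric.result()))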
@@ -213,25 +231,20 @@ def run_customized_training_loop(
         raise ValueError('steps should be an Tensor. Python object may cause '
                          'retracing.')
 
-      def _replicated_step(inputs):
-        """Replicated training step."""
-        inputs, labels = inputs
-        with tf.GradientTape() as tape:
-          model_outputs = model(inputs)
-          loss = loss_fn(labels, model_outputs)
-        tvars = model.trainable_variables
-        grads = tape.gradient(loss, tvars)
-        optimizer.apply_gradients(zip(grads, tvars))
-        # For reporting, the metric takes the mean of losses.
-        train_loss_metric.update_state(loss)
-        if train_metric:
-          train_metric.update_state(labels, model_outputs)
-
       for _ in tf.range(steps):
         strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
 
+    @tf.function
+    def train_single_step(iterator):
+      """Performs a distributed training step.
+
+      Args:
+        iterator: the distributed iterator of training datasets.
+
+      Raises:
+        ValueError: Any of the arguments or tensor shapes are invalid.
+      """
+      strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+
     @tf.function
     def test_step(iterator):
       """Calculates evaluation metrics on distributed devices."""
@@ -289,8 +302,15 @@ def run_customized_training_loop(
       _run_callbacks_on_batch_begin(current_step)
       # Runs several steps in the host while loop.
       steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
-      # Converts steps to a Tensor to avoid tf.function retracing.
-      train_step(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
+
+      if steps == 1:
+        # TODO(zongweiz): merge with train_steps once tf.while_loop
+        # GPU performance bugs are fixed.
+        train_single_step(train_iterator)
+      else:
+        # Converts steps to a Tensor to avoid tf.function retracing.
+        train_steps(train_iterator,
+                    tf.convert_to_tensor(steps, dtype=tf.int32))
       _run_callbacks_on_batch_end(current_step)
       current_step += steps
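The "Converts steps to a Tensor to avoid tf.function retracing" comment, like the earlier "steps should be a Tensor" check, is about how `tf.function` specializes on Python values: every new Python integer triggers a fresh trace, while a scalar Tensor is traced once per dtype/shape. A small demonstration, independent of the BERT code:

import tensorflow as tf

trace_count = 0

@tf.function
def run_loop(steps):
  global trace_count
  trace_count += 1  # executes only while tracing, not on every call
  total = tf.constant(0)
  for _ in tf.range(steps):
    total += 1
  return total

run_loop(10)               # traces for Python int 10
run_loop(20)               # Python int 20 is a new signature -> retraces
run_loop(tf.constant(30))  # traces once for a scalar int32 Tensor
run_loop(tf.constant(40))  # reuses the Tensor trace, no retracing
print(trace_count)         # 3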
@@ -226,8 +226,9 @@ def run_bert(strategy, input_meta_data):
       use_remote_tpu=use_remote_tpu)
 
   if FLAGS.model_export_path:
-    model_saving_utils.export_bert_model(
-        FLAGS.model_export_path, model=trained_model)
+    with tf.device(model_training_utils.get_primary_cpu_task(use_remote_tpu)):
+      model_saving_utils.export_bert_model(
+          FLAGS.model_export_path, model=trained_model)
   return trained_model
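This change pins the SavedModel export to the primary host task so that, on remote TPU runs, variable reads and file writes are not placed on the TPU device. `model_training_utils.get_primary_cpu_task` is referenced but not shown in this diff; the sketch below captures the idea, with the '/job:worker' target for remote TPU workers being an assumption rather than something confirmed by the commit:

import tensorflow as tf

def get_primary_cpu_task(use_remote_tpu=False):
  """Sketch of the helper: choose the host task for I/O-heavy ops.

  Assumption: remote TPU jobs name their host task '/job:worker'; for local
  runs an empty device string leaves placement untouched.
  """
  return '/job:worker' if use_remote_tpu else ''

# Wrapping the export in tf.device keeps the SavedModel write on the host,
# mirroring what the diff does around model_saving_utils.export_bert_model.
with tf.device(get_primary_cpu_task(use_remote_tpu=False)):
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  tf.saved_model.save(model, '/tmp/toy_export')  # illustrative path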
@@ -257,4 +258,5 @@ def main(_):
 if __name__ == '__main__':
   flags.mark_flag_as_required('bert_config_file')
   flags.mark_flag_as_required('input_meta_data_path')
+  flags.mark_flag_as_required('model_dir')
   app.run(main)
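Marking `model_dir` as required makes absl fail fast at startup instead of crashing later on a missing path. A minimal standalone sketch of that behaviour; the toy `main` is illustrative:

from absl import app, flags

flags.DEFINE_string('model_dir', None, 'Directory for checkpoints and exports.')
FLAGS = flags.FLAGS

def main(_):
  print('model_dir =', FLAGS.model_dir)

if __name__ == '__main__':
  flags.mark_flag_as_required('model_dir')
  app.run(main)  # exits with an error message if --model_dir is not passed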