Unverified commit 240623ac authored by saberkun, committed by GitHub

Merged commit includes the following changes: (#7093)

254785517  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Use train_single_step for BERT GPU models to temporarily work around some performance bugs in GPU runs

--
254497647  by hongkuny<hongkuny@google.com>:

    Fix device placement for TPU export model.

--

PiperOrigin-RevId: 254785517
parent 1157c738
@@ -65,7 +65,9 @@ def _float_metric_value(metric):
 
 def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
   """Calculates steps to run on device."""
-  if steps_per_loop <= 1:
+  if steps_per_loop <= 0:
+    raise ValueError('steps_per_loop should be positive integer.')
+  if steps_per_loop == 1:
     return steps_per_loop
   remainder_in_epoch = current_step % steps_per_epoch
   if remainder_in_epoch != 0:
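
The tail of `_steps_to_run` is cut off by the diff context above. For reference, a plausible completion under the new validation (a sketch inferred from the visible fragment, not a verbatim quote of the file): the function returns at most `steps_per_loop` steps, truncated so an inner device loop never runs past an epoch boundary.

def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
  """Calculates steps to run on device."""
  if steps_per_loop <= 0:
    raise ValueError('steps_per_loop should be positive integer.')
  if steps_per_loop == 1:
    return steps_per_loop
  remainder_in_epoch = current_step % steps_per_epoch
  if remainder_in_epoch != 0:
    # Mid-epoch: run only up to the epoch boundary, so host-side
    # bookkeeping still happens exactly at epoch ends.
    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
  return steps_per_loop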
@@ -198,21 +200,6 @@ def run_customized_training_loop(
           eval_metric.__class__.from_config(eval_metric.get_config())
           if eval_metric else None)
 
-  @tf.function
-  def train_step(iterator, steps):
-    """Performs a distributed training step.
-    Args:
-      iterator: the distributed iterator of training datasets.
-      steps: an tf.int32 integer tensor to specify number of steps to run
-        inside host training loop.
-    Raises:
-      ValueError: Any of the arguments or tensor shapes are invalid.
-    """
-    if not isinstance(steps, tf.Tensor):
-      raise ValueError('steps should be an Tensor. Python object may cause '
-                       'retracing.')
-
   def _replicated_step(inputs):
     """Replicated training step."""
@@ -229,9 +216,35 @@ def run_customized_training_loop(
       if train_metric:
         train_metric.update_state(labels, model_outputs)
 
+  @tf.function
+  def train_steps(iterator, steps):
+    """Performs distributed training steps in a loop.
+    Args:
+      iterator: the distributed iterator of training datasets.
+      steps: an tf.int32 integer tensor to specify number of steps to run
+        inside host training loop.
+    Raises:
+      ValueError: Any of the arguments or tensor shapes are invalid.
+    """
+    if not isinstance(steps, tf.Tensor):
+      raise ValueError('steps should be an Tensor. Python object may cause '
+                       'retracing.')
+    for _ in tf.range(steps):
+      strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+
+  @tf.function
+  def train_single_step(iterator):
+    """Performs a distributed training step.
+    Args:
+      iterator: the distributed iterator of training datasets.
+    Raises:
+      ValueError: Any of the arguments or tensor shapes are invalid.
+    """
+    strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+
   @tf.function
   def test_step(iterator):
     """Calculates evaluation metrics on distributed devices."""
@@ -289,8 +302,15 @@ def run_customized_training_loop(
       _run_callbacks_on_batch_begin(current_step)
       # Runs several steps in the host while loop.
       steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
-      # Converts steps to a Tensor to avoid tf.function retracing.
-      train_step(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
+      if steps == 1:
+        # TODO(zongweiz): merge with train_steps once tf.while_loop
+        # GPU performance bugs are fixed.
+        train_single_step(train_iterator)
+      else:
+        # Converts steps to a Tensor to avoid tf.function retracing.
+        train_steps(train_iterator,
+                    tf.convert_to_tensor(steps, dtype=tf.int32))
       _run_callbacks_on_batch_end(current_step)
       current_step += steps
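
Two notes on the dispatch above, demonstrated by a standalone sketch (not the repo's code). First, inside a tf.function, a Python `for` over `tf.range` is staged as a single tf.while_loop op, which is the construct the TODO's GPU performance caveat refers to; `train_single_step` emits no loop op at all. Second, a tf.function retraces for every distinct Python value it is called with, whereas a tensor argument generalizes to one trace per dtype/shape, which is why `steps` is converted to a tensor:

import tensorflow as tf

@tf.function
def run_n_steps(steps):
  total = tf.constant(0)
  # Staged as one tf.while_loop op when `steps` is a tensor.
  for _ in tf.range(steps):
    total += 1
  return total

run_n_steps(10)  # traces a graph specialized to the Python int 10
run_n_steps(20)  # a different Python int: triggers a retrace

run_n_steps(tf.convert_to_tensor(10, dtype=tf.int32))  # one trace for int32 scalars
run_n_steps(tf.convert_to_tensor(20, dtype=tf.int32))  # reuses that trace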
@@ -226,6 +226,7 @@ def run_bert(strategy, input_meta_data):
       use_remote_tpu=use_remote_tpu)
 
   if FLAGS.model_export_path:
+    with tf.device(model_training_utils.get_primary_cpu_task(use_remote_tpu)):
       model_saving_utils.export_bert_model(
           FLAGS.model_export_path, model=trained_model)
   return trained_model
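
`get_primary_cpu_task` is the repo's own helper for deciding where host-side ops such as the export above should run; its body is not part of this diff. A plausible reading of its behavior, assumed here purely for illustration:

def get_primary_cpu_task(use_remote_tpu=False):
  """Returns the device string under which host-side ops should run.

  Assumed behavior: with a remote TPU, ops such as model export must be
  pinned to the `worker` job rather than the coordinator; locally, the
  empty string leaves placement to the default device.
  """
  return '/job:worker' if use_remote_tpu else ''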
@@ -257,4 +258,5 @@ def main(_):
 if __name__ == '__main__':
   flags.mark_flag_as_required('bert_config_file')
   flags.mark_flag_as_required('input_meta_data_path')
+  flags.mark_flag_as_required('model_dir')
   app.run(main)