Unverified Commit e21dcdd0 authored by Hongkun Yu, committed by GitHub

Merged commit includes the following changes: (#7221)

258208153  by hongkuny<hongkuny@google.com>:

    Adds run_eagerly option for bert.

--

PiperOrigin-RevId: 258208153
parent 492f8c92
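
For context, a minimal sketch of how the new flag is meant to be wired up. Only `steps_per_loop` and `run_eagerly` are visible in this diff; the import path and the remaining parameter names are assumptions, and the `my_*` functions are hypothetical placeholders for the usual BERT setup.

    import tensorflow as tf
    from official.bert import model_training_utils  # import path assumed

    strategy = tf.distribute.MirroredStrategy()

    model = model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=my_model_fn,        # hypothetical model builder
        loss_fn=my_loss_fn,          # hypothetical loss function
        train_input_fn=my_input_fn,  # hypothetical tf.data input function
        steps_per_epoch=1000,
        steps_per_loop=1,            # must stay 1 when run_eagerly=True
        epochs=3,
        run_eagerly=True)            # step functions run uncompiled, for debugging

A MirroredStrategy (or the default strategy) is assumed here; as the checks added below enforce, TPUStrategy cannot be combined with run_eagerly=True.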
@@ -102,7 +102,8 @@ def run_customized_training_loop(
     metric_fn=None,
     init_checkpoint=None,
     use_remote_tpu=False,
-    custom_callbacks=None):
+    custom_callbacks=None,
+    run_eagerly=False):
   """Run BERT pretrain model training using low-level API.

   Arguments:
@@ -139,6 +140,8 @@ def run_customized_training_loop(
     custom_callbacks: A list of Keras Callbacks objects to run during
       training. More specifically, `on_batch_begin()`, `on_batch_end()`,
       methods are invoked during training.
+    run_eagerly: Whether to run model training in pure eager execution. This
+      should be disabled for TPUStrategy.

   Returns:
       Trained model.
@@ -168,6 +171,16 @@ def run_customized_training_loop(
     steps_per_loop = steps_per_epoch

   assert tf.executing_eagerly()

+  if run_eagerly:
+    if steps_per_loop > 1:
+      raise ValueError(
+          'steps_per_loop is used for performance optimization. When you want '
+          'to run eagerly, you cannot leverage graph mode loop.')
+    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+      raise ValueError(
+          'TPUStrategy should not run eagerly as it heavily relies on graph'
+          ' optimization for the distributed system.')
+
   if eval_input_fn and (eval_steps is None or metric_fn is None):
     raise ValueError(
         '`eval_steps` and `metric_fn` are required when `eval_input_fn` '
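
The two checks added above fail fast, rejecting incompatible option combinations before any training state is built. A standalone sketch of the same guard as a hypothetical helper (the commit itself inlines these checks in the training loop):

    import tensorflow as tf

    def validate_run_eagerly(strategy, steps_per_loop, run_eagerly):
      """Rejects option combinations that cannot work in pure eager mode."""
      if not run_eagerly:
        return
      if steps_per_loop > 1:
        # The multi-step inner loop only pays off when compiled into a graph.
        raise ValueError('steps_per_loop must be 1 when running eagerly.')
      if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
        # TPUs depend on graph compilation; eager stepping is unsupported.
        raise ValueError('TPUStrategy cannot run eagerly.')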
@@ -254,7 +267,6 @@ def run_customized_training_loop(
     for _ in tf.range(steps):
       strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))

-  @tf.function
   def train_single_step(iterator):
     """Performs a distributed training step.
@@ -265,7 +277,6 @@ def run_customized_training_loop(
     """
     strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))

-  @tf.function
   def test_step(iterator):
     """Calculates evaluation metrics on distributed devices."""
@@ -278,6 +289,10 @@ def run_customized_training_loop(
       strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))

+  if not run_eagerly:
+    train_single_step = tf.function(train_single_step)
+    test_step = tf.function(test_step)
+
   def _run_evaluation(current_training_step, test_iterator):
     """Runs validation steps and aggregates metrics."""
     for _ in range(eval_steps):
...
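
The last hunk is the core of the change: instead of hard-coding @tf.function decorators, the step functions are wrapped only when graph execution is wanted, so a single code path serves both eager and graph modes. A self-contained sketch of that pattern (the function body here is a placeholder, not the real step function):

    import tensorflow as tf

    def train_single_step(iterator):
      # Placeholder body; the real function dispatches a distributed step.
      return next(iterator) * 2

    run_eagerly = False  # illustrative flag value

    if not run_eagerly:
      # tf.function(f) is equivalent to decorating f with @tf.function.
      train_single_step = tf.function(train_single_step)

    it = iter(tf.data.Dataset.range(4))
    print(train_single_step(it))  # traced into a graph on first call

With run_eagerly=True the function executes line by line, so breakpoints and print debugging work; with the default False, the first call traces it into a graph for speed.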