"vscode:/vscode.git/clone" did not exist on "f2c618817be8b4e265e368eb39ecbba8714807d3"
Unverified Commit e21dcdd0 authored by Hongkun Yu's avatar Hongkun Yu Committed by GitHub
Browse files

Merged commit includes the following changes: (#7221)

258208153  by hongkuny<hongkuny@google.com>:

    Adds run_eagerly option for bert.

--

PiperOrigin-RevId: 258208153
parent 492f8c92
......@@ -102,7 +102,8 @@ def run_customized_training_loop(
metric_fn=None,
init_checkpoint=None,
use_remote_tpu=False,
custom_callbacks=None):
custom_callbacks=None,
run_eagerly=False):
"""Run BERT pretrain model training using low-level API.
Arguments:
......@@ -139,6 +140,8 @@ def run_customized_training_loop(
custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training.
run_eagerly: Whether to run model training in pure eager execution. This
should be disabled for TPUStrategy.
Returns:
Trained model.
......@@ -168,6 +171,16 @@ def run_customized_training_loop(
steps_per_loop = steps_per_epoch
assert tf.executing_eagerly()
if run_eagerly:
if steps_per_loop > 1:
raise ValueError(
'steps_per_loop is used for performance optimization. When you want '
'to run eagerly, you cannot leverage graph mode loop.')
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
raise ValueError(
'TPUStrategy should not run eagerly as it heavily replies on graph'
' optimization for the distributed system.')
if eval_input_fn and (eval_steps is None or metric_fn is None):
raise ValueError(
'`eval_step` and `metric_fn` are required when `eval_input_fn ` '
......@@ -254,7 +267,6 @@ def run_customized_training_loop(
for _ in tf.range(steps):
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
@tf.function
def train_single_step(iterator):
"""Performs a distributed training step.
......@@ -265,7 +277,6 @@ def run_customized_training_loop(
"""
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
@tf.function
def test_step(iterator):
"""Calculates evaluation metrics on distributed devices."""
......@@ -278,6 +289,10 @@ def run_customized_training_loop(
strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
if not run_eagerly:
train_single_step = tf.function(train_single_step)
test_step = tf.function(test_step)
def _run_evaluation(current_training_step, test_iterator):
"""Runs validation steps and aggregate metrics."""
for _ in range(eval_steps):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment