Unverified Commit e0e6d981 authored by saberkun's avatar saberkun Committed by GitHub
Browse files

Merged commit includes the following changes: (#7060)

254069984  by hongkuny<hongkuny@google.com>:
    Automated rollback of changelist 254060732.

254061429  by hongkuny<hongkuny@google.com>:

    Use host while loop for training steps.

--
254060732  by yifeif<yifeif@google.com>:
    Automated rollback of changelist 254027750.

254027750  by hongkuny<hongkuny@google.com>:

    Internal change

PiperOrigin-RevId: 254069984
parent 695265c8
......@@ -199,8 +199,19 @@ def run_customized_training_loop(
if eval_metric else None)
@tf.function
def train_step(iterator):
"""Performs a distributed training step."""
def train_step(iterator, steps):
"""Performs a distributed training step.
Args:
iterator: the distributed iterator of training datasets.
steps: an tf.int32 integer tensor to specify number of steps to run
inside host training loop.
Raises:
ValueError: Any of the arguments or tensor shapes are invalid.
"""
if not isinstance(steps, tf.Tensor):
raise ValueError('steps should be an Tensor. Python object may cause '
'retracing.')
def _replicated_step(inputs):
"""Replicated training step."""
......@@ -218,6 +229,7 @@ def run_customized_training_loop(
if train_metric:
train_metric.update_state(labels, model_outputs)
for _ in tf.range(steps):
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
@tf.function
......@@ -274,13 +286,13 @@ def run_customized_training_loop(
if train_metric:
train_metric.reset_states()
state_step = current_step
_run_callbacks_on_batch_begin(state_step)
for _ in range(
_steps_to_run(state_step, steps_per_epoch, steps_per_loop)):
current_step += 1
train_step(train_iterator)
_run_callbacks_on_batch_end(state_step)
_run_callbacks_on_batch_begin(current_step)
# Runs several steps in the host while loop.
steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
# Converts steps to a Tensor to avoid tf.function retracing.
train_step(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
_run_callbacks_on_batch_end(current_step)
current_step += steps
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment