"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "d94042db9c37d23329e5f6ba3c2e2c096d862be5"
Unverified Commit e0e6d981 authored by saberkun, committed by GitHub

Merged commit includes the following changes: (#7060)

254069984  by hongkuny<hongkuny@google.com>:
    Automated rollback of changelist 254060732.

254061429  by hongkuny<hongkuny@google.com>:

    Use host while loop for training steps.

--
254060732  by yifeif<yifeif@google.com>:
    Automated rollback of changelist 254027750.

254027750  by hongkuny<hongkuny@google.com>:

    Internal change

PiperOrigin-RevId: 254069984
parent 695265c8
@@ -199,8 +199,19 @@ def run_customized_training_loop(
       if eval_metric else None)
 
   @tf.function
-  def train_step(iterator):
-    """Performs a distributed training step."""
+  def train_step(iterator, steps):
+    """Performs a distributed training step.
+
+    Args:
+      iterator: the distributed iterator of training datasets.
+      steps: an tf.int32 integer tensor to specify number of steps to run
+        inside host training loop.
+
+    Raises:
+      ValueError: Any of the arguments or tensor shapes are invalid.
+    """
+    if not isinstance(steps, tf.Tensor):
+      raise ValueError('steps should be an Tensor. Python object may cause '
+                       'retracing.')
 
     def _replicated_step(inputs):
       """Replicated training step."""
@@ -218,7 +229,8 @@ def run_customized_training_loop(
       if train_metric:
         train_metric.update_state(labels, model_outputs)
 
-    strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+    for _ in tf.range(steps):
+      strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
 
   @tf.function
   def test_step(iterator):
@@ -274,13 +286,13 @@ def run_customized_training_loop(
       if train_metric:
         train_metric.reset_states()
 
-      state_step = current_step
-      _run_callbacks_on_batch_begin(state_step)
-      for _ in range(
-          _steps_to_run(state_step, steps_per_epoch, steps_per_loop)):
-        current_step += 1
-        train_step(train_iterator)
-      _run_callbacks_on_batch_end(state_step)
+      _run_callbacks_on_batch_begin(current_step)
+      # Runs several steps in the host while loop.
+      steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
+      # Converts steps to a Tensor to avoid tf.function retracing.
+      train_step(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
+      _run_callbacks_on_batch_end(current_step)
+      current_step += steps
 
       # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
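For context on the change above, here is a minimal standalone sketch, not code from this repository, of the pattern the diff introduces: the per-step work runs inside a loop that is traced into the graph, and the step count is passed as a tf.int32 tensor so that tf.function does not retrace a new graph for every distinct Python value. The train_for function and its dummy loop body are illustrative assumptions.

import tensorflow as tf

@tf.function
def train_for(steps):
  """Toy stand-in for train_step: runs `steps` iterations inside one trace."""
  if not isinstance(steps, tf.Tensor):
    # Mirrors the check added in the diff: a Python int would bake the loop
    # length into the trace and force a retrace for each distinct value.
    raise ValueError('steps should be a Tensor; a Python int causes retracing.')
  total = tf.constant(0.0)
  for _ in tf.range(steps):  # AutoGraph lowers this to a graph tf.while_loop.
    total += 1.0             # Stand-in for one replicated training step.
  return total

# One trace is reused for any step count because `steps` stays symbolic:
train_for(tf.convert_to_tensor(100, dtype=tf.int32))
train_for(tf.convert_to_tensor(250, dtype=tf.int32))  # No retracing here.

Batching several steps into a single traced call avoids returning to Python after every step, which appears to be the intent of "Use host while loop for training steps" above.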