Commit 412f4d2e authored by Ruoxin Sang, committed by A. Unique TensorFlower
Browse files

Add host training support for StandardEvaluator.

PiperOrigin-RevId: 348868176
parent b7930ff9
......@@ -50,12 +50,12 @@ class StandardTrainerOptions:
"""Advanced options for `orbit.StandardTrainer`.
Attributes:
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tf_function: A boolean indicating whether to apply `tf.function` to the
training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tpu_summary_optimization: A boolean indicating whether to enable a
performance optimization for summaries in TPUs. Writing summaries
conditionally with outside compilation on TPUs can be extremely slow. If
......@@ -63,8 +63,8 @@ class StandardTrainerOptions:
(one with summary calls, and one without). The program with summaries runs
only for one step when summaries should be recorded.
"""
use_tf_while_loop: bool = True
use_tf_function: bool = True
use_tf_while_loop: bool = True
use_tpu_summary_optimization: bool = False
......@@ -215,14 +215,30 @@ class StandardEvaluatorOptions:
evaluation loop. This will only affect the body of the loop (involving
`eval_step`); `eval_begin` and `eval_end` will always be run
in eager mode.
use_tf_while_loop: A boolean indicating whether to run the evaluation loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
"""
use_tf_function: bool = True
use_tf_while_loop: bool = False
def _create_eval_loop_fn(eval_step_fn, has_state: bool,
                         options: StandardEvaluatorOptions):
  """Creates an evaluation loop function wrapping `eval_step_fn`.

  Args:
    eval_step_fn: A function taking a nested structure of iterators and
      running a single evaluation step.
    has_state: Whether evaluation threads a `state` value through the loop
      (i.e., `eval_begin` returned a non-`None` value).
    options: A `StandardEvaluatorOptions` instance controlling whether the
      loop is built with `tf.while_loop` and/or wrapped in `tf.function`.

  Returns:
    A callable implementing the evaluation loop.
  """
  if options.use_tf_while_loop:
    # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
    # even when it is not used inside the loop. To workaround this limitation,
    # we have to build two tf.functions for it.
    if has_state:
      loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
    else:
      loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
    loop_fn = tf.function(loop_fn)
  else:
    if options.use_tf_function:
      eval_step_fn = tf.function(eval_step_fn)
    loop_fn = loop_fns.create_loop_fn(eval_step_fn)
  return loop_fn
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
......@@ -254,7 +270,12 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
`DistributedDataset`.
options: An `orbit.StandardEvaluatorOptions` instance.
"""
self._eval_options = options or StandardEvaluatorOptions()
options = options or StandardEvaluatorOptions()
if options.use_tf_while_loop and not options.use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported")
self._eval_options = options
self._eval_dataset = eval_dataset
self._eval_loop_fn = None
......@@ -268,14 +289,26 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
Returns:
The output of `self.eval_end()`.
Raises:
ValueError: If `options.use_tf_while_loop` is `True` and `num_steps` is
unspecified.
"""
if self._eval_options.use_tf_while_loop and num_steps == -1:
raise ValueError("Looping until exhausted is not supported if "
"`options.use_tf_while_loop` is `True`")
outputs = self.eval_begin() # pylint: disable=assignment-from-no-return
has_state = outputs is not None
if self._eval_loop_fn is None:
self._eval_loop_fn = _create_eval_loop_fn(
self.eval_step, options=self._eval_options)
self.eval_step, has_state=has_state, options=self._eval_options)
eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
if self._eval_options.use_tf_while_loop and not has_state:
self._eval_loop_fn(eval_iter, num_steps)
else:
outputs = self._eval_loop_fn(
eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce)
......
......@@ -14,6 +14,8 @@
# ==============================================================================
"""Tests for orbit.standard_runner."""
from absl.testing import parameterized
from orbit import standard_runner
from orbit import utils
......@@ -79,7 +81,36 @@ class TestEvaluator(standard_runner.StandardEvaluator):
return self.global_step.numpy()
class StandardRunnerTest(tf.test.TestCase):
class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
  """A StandardEvaluator subclass for tests."""

  def __init__(self, options=None):
    self.strategy = tf.distribute.get_strategy()

    def _dataset_fn(_):
      return tf.data.Dataset.range(10)

    super().__init__(
        eval_dataset=self.strategy.distribute_datasets_from_function(
            _dataset_fn),
        options=options)

  def eval_begin(self):
    # Initial accumulator: a single zero entry that later outputs get
    # concatenated onto.
    return tf.constant((0.0,))

  def eval_reduce(self, state, step_outputs):
    # Append this step's per-replica outputs to the running state.
    return tf.concat([state, step_outputs], 0)

  def eval_step(self, iterator):

    def _replica_sum(value):
      return tf.reduce_sum(tf.cast(value, tf.float32))

    per_replica = self.strategy.run(_replica_sum, args=(next(iterator),))
    return self.strategy.experimental_local_results(per_replica)

  def eval_end(self, outputs):
    return tf.reduce_sum(outputs)
class StandardRunnerTest(parameterized.TestCase):
def test_default_trainer(self):
trainer = TestTrainer()
......@@ -91,10 +122,20 @@ class StandardRunnerTest(tf.test.TestCase):
trainer = TestTrainer(options)
self.assertEqual(trainer.train(tf.constant(10)), 10)
def test_default_evaluator(self):
evaluator = TestEvaluator()
@parameterized.named_parameters(("use_tf_while_loop", True), ("", False))
def test_default_evaluator(self, use_tf_while_loop):
  # Ten steps over the counting dataset should advance the step counter to 10
  # regardless of whether the loop is a tf.while_loop or a Python loop.
  evaluator = TestEvaluator(
      standard_runner.StandardEvaluatorOptions(
          use_tf_while_loop=use_tf_while_loop))
  self.assertEqual(evaluator.evaluate(tf.constant(10)), 10)
@parameterized.named_parameters(("use_tf_while_loop", True), ("", False))
def test_evaluator_with_outputs_aggregation(self, use_tf_while_loop):
  # Sum of range(10) is 45; the evaluator accumulates per-step sums and
  # reduces them at the end, for both loop implementations.
  evaluator = TestEvaluatorWithOutputsAggregation(
      standard_runner.StandardEvaluatorOptions(
          use_tf_while_loop=use_tf_while_loop))
  self.assertEqual(evaluator.evaluate(tf.constant(10)), 45)
if __name__ == '__main__':
if __name__ == "__main__":
tf.test.main()
......@@ -117,6 +117,60 @@ def create_tf_while_loop_fn(step_fn):
return loop_fn
def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allowing a `state`
  to be accumulated over multiple iterations of the loop. Note that the
  structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Its return value is passed to `reduce_fn` to
      update the accumulated state (unlike `create_tf_while_loop_fn`, outputs
      are NOT ignored here).

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the
    `loop_fn_with_state` definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`, threading `state`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    for _ in tf.range(num_steps):
      # Relax the shapes within the loop, so the shape of `state` can change
      # across iterations. This is useful to aggregate outputs from each step
      # and concat to `state`.
      tf.autograph.experimental.set_loop_options(
          shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank))
                            for t in tf.nest.flatten(state)
                            if tf.is_tensor(t)])
      outputs = step_fn(iterator)
      state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
"""Implements a two-program approach for optimizing summaries on TPU.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment