Commit 412f4d2e authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Add host training support for StandardEvaluator.

PiperOrigin-RevId: 348868176
parent b7930ff9
...@@ -50,12 +50,12 @@ class StandardTrainerOptions: ...@@ -50,12 +50,12 @@ class StandardTrainerOptions:
"""Advanced options for `orbit.StandardTrainer`. """Advanced options for `orbit.StandardTrainer`.
Attributes: Attributes:
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tf_function: A boolean indicating whether to apply `tf.function` to the use_tf_function: A boolean indicating whether to apply `tf.function` to the
training loop. This will only affect the body of the loop (involving training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run `train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode. in eager mode.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tpu_summary_optimization: A boolean indicating whether to enable a use_tpu_summary_optimization: A boolean indicating whether to enable a
performance optimization for summaries in TPUs. Writing summaries performance optimization for summaries in TPUs. Writing summaries
conditionally with outside compilation on TPUs can be extremely slow. If conditionally with outside compilation on TPUs can be extremely slow. If
...@@ -63,8 +63,8 @@ class StandardTrainerOptions: ...@@ -63,8 +63,8 @@ class StandardTrainerOptions:
(one with summary calls, and one without). The program with summaries runs (one with summary calls, and one without). The program with summaries runs
only for one step when summaries should be recorded. only for one step when summaries should be recorded.
""" """
use_tf_while_loop: bool = True
use_tf_function: bool = True use_tf_function: bool = True
use_tf_while_loop: bool = True
use_tpu_summary_optimization: bool = False use_tpu_summary_optimization: bool = False
...@@ -215,14 +215,30 @@ class StandardEvaluatorOptions: ...@@ -215,14 +215,30 @@ class StandardEvaluatorOptions:
training loop. This will only affect the body of the loop (involving training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run `train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode. in eager mode.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
""" """
use_tf_function: bool = True use_tf_function: bool = True
use_tf_while_loop: bool = False
def _create_eval_loop_fn(eval_step_fn, has_state: bool,
                         options: StandardEvaluatorOptions):
  """Create evaluation loop function."""
  if options.use_tf_while_loop:
    # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
    # even when it is not used inside the loop. To workaround this limitation,
    # we have to build two tf.functions for it.
    if has_state:
      loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
    else:
      loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
    loop_fn = tf.function(loop_fn)
  else:
    if options.use_tf_function:
      eval_step_fn = tf.function(eval_step_fn)
    loop_fn = loop_fns.create_loop_fn(eval_step_fn)
  return loop_fn
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
...@@ -254,7 +270,12 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): ...@@ -254,7 +270,12 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
`DistributedDataset`. `DistributedDataset`.
options: An `orbit.StandardEvaluatorOptions` instance. options: An `orbit.StandardEvaluatorOptions` instance.
""" """
self._eval_options = options or StandardEvaluatorOptions() options = options or StandardEvaluatorOptions()
if options.use_tf_while_loop and not options.use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported")
self._eval_options = options
self._eval_dataset = eval_dataset self._eval_dataset = eval_dataset
self._eval_loop_fn = None self._eval_loop_fn = None
...@@ -268,16 +289,28 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): ...@@ -268,16 +289,28 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
Returns: Returns:
The output of `self.eval_end()`. The output of `self.eval_end()`.
Raises:
ValueError: If `options.use_tf_while_loop` is `True` and `num_steps` is
unspecified.
""" """
if self._eval_options.use_tf_while_loop and num_steps == -1:
raise ValueError("Looping until exhausted is not supported if "
"`options.use_tf_while_loop` is `True`")
outputs = self.eval_begin() # pylint: disable=assignment-from-no-return outputs = self.eval_begin() # pylint: disable=assignment-from-no-return
has_state = outputs is not None
if self._eval_loop_fn is None: if self._eval_loop_fn is None:
self._eval_loop_fn = _create_eval_loop_fn( self._eval_loop_fn = _create_eval_loop_fn(
self.eval_step, options=self._eval_options) self.eval_step, has_state=has_state, options=self._eval_options)
eval_iter = tf.nest.map_structure(iter, self.eval_dataset) eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
outputs = self._eval_loop_fn( if self._eval_options.use_tf_while_loop and not has_state:
eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce) self._eval_loop_fn(eval_iter, num_steps)
else:
outputs = self._eval_loop_fn(
eval_iter, num_steps, state=outputs, reduce_fn=self.eval_reduce)
if outputs is None: if outputs is None:
return self.eval_end() return self.eval_end()
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# ============================================================================== # ==============================================================================
"""Tests for orbit.standard_runner.""" """Tests for orbit.standard_runner."""
from absl.testing import parameterized
from orbit import standard_runner from orbit import standard_runner
from orbit import utils from orbit import utils
...@@ -79,7 +81,36 @@ class TestEvaluator(standard_runner.StandardEvaluator): ...@@ -79,7 +81,36 @@ class TestEvaluator(standard_runner.StandardEvaluator):
return self.global_step.numpy() return self.global_step.numpy()
class TestEvaluatorWithOutputsAggregation(standard_runner.StandardEvaluator):
  """A StandardEvaluator subclass for tests."""

  def __init__(self, options=None):
    self.strategy = tf.distribute.get_strategy()
    dataset = self.strategy.distribute_datasets_from_function(
        lambda _: tf.data.Dataset.range(10))
    super().__init__(eval_dataset=dataset, options=options)

  def eval_begin(self):
    return tf.constant((0.0,))

  def eval_reduce(self, state, step_outputs):
    state = tf.concat([state, step_outputs], 0)
    return state

  def eval_step(self, iterator):

    def replica_step(x):
      x = tf.cast(x, tf.float32)
      return tf.reduce_sum(x)

    return self.strategy.experimental_local_results(
        self.strategy.run(replica_step, args=(next(iterator),)))

  def eval_end(self, outputs):
    return tf.reduce_sum(outputs)
class StandardRunnerTest(parameterized.TestCase):
def test_default_trainer(self): def test_default_trainer(self):
trainer = TestTrainer() trainer = TestTrainer()
...@@ -91,10 +122,20 @@ class StandardRunnerTest(tf.test.TestCase): ...@@ -91,10 +122,20 @@ class StandardRunnerTest(tf.test.TestCase):
trainer = TestTrainer(options) trainer = TestTrainer(options)
self.assertEqual(trainer.train(tf.constant(10)), 10) self.assertEqual(trainer.train(tf.constant(10)), 10)
def test_default_evaluator(self): @parameterized.named_parameters(("use_tf_while_loop", True), ("", False))
evaluator = TestEvaluator() def test_default_evaluator(self, use_tf_while_loop):
options = standard_runner.StandardEvaluatorOptions(
use_tf_while_loop=use_tf_while_loop)
evaluator = TestEvaluator(options)
self.assertEqual(evaluator.evaluate(tf.constant(10)), 10) self.assertEqual(evaluator.evaluate(tf.constant(10)), 10)
@parameterized.named_parameters(("use_tf_while_loop", True), ("", False))
def test_evaluator_with_outputs_aggregation(self, use_tf_while_loop):
options = standard_runner.StandardEvaluatorOptions(
use_tf_while_loop=use_tf_while_loop)
evaluator = TestEvaluatorWithOutputsAggregation(options)
self.assertEqual(evaluator.evaluate(tf.constant(10)), 45)
if __name__ == '__main__': if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -117,6 +117,60 @@ def create_tf_while_loop_fn(step_fn): ...@@ -117,6 +117,60 @@ def create_tf_while_loop_fn(step_fn):
return loop_fn return loop_fn
def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function that threads a `state` through the loop.

  This function is similar to `create_tf_while_loop_fn`, but allowing a `state`
  to be accumulated over multiple iterations of the loop: after every step,
  `reduce_fn` folds the step's outputs into the running `state`. Note that the
  nested structure of the `state` cannot be changed across iterations (only
  tensor shapes within it may change, thanks to the relaxed shape invariants
  set inside the loop).

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Its return value is passed to `reduce_fn` as the
      per-step `value`.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`, reducing into `state`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.
    """
    # A Python int here would be baked into the traced graph, causing a
    # retrace for every distinct value; require a Tensor up front.
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    for _ in tf.range(num_steps):
      # Relax the shapes within the loop, so the shape of `state` can change
      # across iterations. This is useful to aggregate outputs from each step
      # and concat to `state`.
      tf.autograph.experimental.set_loop_options(
          shape_invariants=[(t, tf.TensorShape([None] * t.shape.rank))
                            for t in tf.nest.flatten(state)
                            if tf.is_tensor(t)])
      outputs = step_fn(iterator)
      state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction): class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
"""Implements a two-program approach for optimizing summaries on TPU. """Implements a two-program approach for optimizing summaries on TPU.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment