Commit 0e6f8848 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add a new option "recreate_iterator_for_each_eval" into StandardEvaluator.

In the current implementation of StandardEvaluator, the iterator over the evaluation dataset goes back to the beginning (a new iterator is actually created) every time evaluate() is called.

In some cases, the iterator needs to keep advancing indefinitely, so add the new option "recreate_iterator_for_each_eval" to control this behavior.

In some cases, creating the iterator is time-consuming, for example when creating an iterator over a distributed dataset for multiple workers in eager mode. In such cases it is better to use a repeated dataset instead of recreating the iterator.

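If "recreate_iterator_for_each_eval" is True, a new iterator is created for each evaluation, so every evaluation starts from the beginning of the dataset; otherwise, the iterator keeps advancing from where the previous evaluation stopped.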

PiperOrigin-RevId: 364356169
parent e27323d0
......@@ -220,9 +220,17 @@ class StandardEvaluatorOptions:
in eager mode.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
recreate_iterator_for_each_eval: A boolean indicating whether to recreate a
new iterator for the evaluation dataset before each round of evaluation,
which implies each round of evaluation starts from the beginning of
the evaluation dataset. For example, the evaluation dataset is
`[1, 2, 3, 4]`, batch size is 1 and evaluation steps is 2. If `True`, the
data to be evaluated is [1, 2] every time. If `False`, the iterator
state is maintained between calls to `StandardEvaluator.evaluate()`.
"""
use_tf_function: bool = True
use_tf_while_loop: bool = False
recreate_iterator_for_each_eval: bool = True
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
......@@ -261,6 +269,7 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
self._eval_options = options
self._eval_dataset = eval_dataset
self._eval_iter = None
self._eval_loop_fn = None
def create_eval_loop_fn(self, has_state: bool):
......@@ -313,7 +322,15 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
if self._eval_loop_fn is None:
self._eval_loop_fn = self.create_eval_loop_fn(has_state)
# If `recreate_iterator_for_each_eval` is `True`, `self._eval_iter` is
# always None.
if self._eval_iter is None:
eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
if not self._eval_options.recreate_iterator_for_each_eval:
self._eval_iter = eval_iter
else:
eval_iter = self._eval_iter
if self._eval_options.use_tf_while_loop and not has_state:
self._eval_loop_fn(eval_iter, num_steps)
else:
......@@ -421,3 +438,4 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
`DistributedDataset`.
"""
self._eval_dataset = eval_dataset
self._eval_iter = None
......@@ -136,6 +136,17 @@ class StandardRunnerTest(parameterized.TestCase):
evaluator = TestEvaluatorWithOutputsAggregation(options)
self.assertEqual(evaluator.evaluate(tf.constant(10)), 45)
# Checks that the `recreate_iterator_for_each_eval` option controls whether
# consecutive `evaluate()` calls restart the evaluation dataset or continue
# from where the previous call stopped.
#
# Each parameterized case builds the options with the given flag, calls
# `evaluate(tf.constant(5))` twice on the same evaluator, and compares the two
# returned values with `sum_for_1st_time` and `sum_for_2nd_time`:
#   * flag True  -> both calls are expected to return 10;
#   * flag False -> the first call is expected to return 10, the second 35.
#
# NOTE(review): the expected values are consistent with a dataset yielding
# 0, 1, 2, ... (0+1+2+3+4 = 10, 5+6+7+8+9 = 35) — confirm against
# `TestEvaluatorWithOutputsAggregation`, which is defined outside this excerpt.
@parameterized.named_parameters(
("recreate_iterator_for_each_eval", True, 10, 10),
("not_recreate_iterator_for_each_eval", False, 10, 35))
def test_evaluator_with_repeat_dataset(self, recreate_iterator_for_each_eval,
sum_for_1st_time, sum_for_2nd_time):
options = standard_runner.StandardEvaluatorOptions(
recreate_iterator_for_each_eval=recreate_iterator_for_each_eval)
evaluator = TestEvaluatorWithOutputsAggregation(options)
self.assertEqual(evaluator.evaluate(tf.constant(5)), sum_for_1st_time)
self.assertEqual(evaluator.evaluate(tf.constant(5)), sum_for_2nd_time)
# Script entry point: when this module is executed directly, hand control to
# `tf.test.main()`.
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment