Commit 0edeca54 authored by Dan Holtmann-Rice, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 325088513
parent 29d45e88
@@ -107,9 +107,12 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
             .datasets_num_private_threads,
         dtype=self.dtype,
         drop_remainder=True)
-    orbit.StandardTrainer.__init__(self, train_dataset,
-                                   flags_obj.use_tf_while_loop,
-                                   flags_obj.use_tf_function)
+    orbit.StandardTrainer.__init__(
+        self,
+        train_dataset,
+        options=orbit.StandardTrainerOptions(
+            use_tf_while_loop=flags_obj.use_tf_while_loop,
+            use_tf_function=flags_obj.use_tf_function))
     if not flags_obj.skip_eval:
       eval_dataset = orbit.utils.make_distributed_dataset(
           self.strategy,
@@ -119,8 +122,11 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
           batch_size=self.batch_size,
           parse_record_fn=imagenet_preprocessing.parse_record,
           dtype=self.dtype)
-      orbit.StandardEvaluator.__init__(self, eval_dataset,
-                                       flags_obj.use_tf_function)
+      orbit.StandardEvaluator.__init__(
+          self,
+          eval_dataset,
+          options=orbit.StandardEvaluatorOptions(
+              use_tf_function=flags_obj.use_tf_function))

   def train_loop_begin(self):
     """See base class."""
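For reference, the hunks above switch the `ResnetRunnable` base-class constructor calls from positional booleans to the new options dataclasses. A minimal sketch of the objects being constructed, with literal values standing in for the corresponding `flags_obj` fields (this snippet is illustrative and not part of the commit):

import orbit

# Stand-ins for flags_obj.use_tf_while_loop / flags_obj.use_tf_function.
trainer_options = orbit.StandardTrainerOptions(
    use_tf_while_loop=True,
    use_tf_function=True)
evaluator_options = orbit.StandardEvaluatorOptions(
    use_tf_function=True)

# The runnable then forwards these to the Orbit base classes:
#   orbit.StandardTrainer.__init__(self, train_dataset, options=trainer_options)
#   orbit.StandardEvaluator.__init__(self, eval_dataset, options=evaluator_options)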
@@ -221,7 +221,10 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
         self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
     )
     standard_runner.StandardTrainer.__init__(
-        self, train_dataset, use_tpu_summary_optimization=True)
+        self,
+        train_dataset,
+        options=standard_runner.StandardTrainerOptions(
+            use_tpu_summary_optimization=True))

   def build_train_dataset(self):
     return self.strategy.experimental_distribute_datasets_from_function(
@@ -23,20 +23,22 @@ import tensorflow as tf

 @dataclasses.dataclass(frozen=True)
-class TrainerOverrides:
-  """Advanced overrides for Orbit trainers.
+class StandardTrainerOptions:
+  """Advanced options for `orbit.StandardTrainer`.

   Attributes:
-    use_tf_while_loop: A boolean indicates whether to wrap the train step with
-      a `tf.while_loop`.
-    use_tf_function: A boolean indicates whether a `tf.function` will be used.
-      If False, training will run on pure eager mode.
-    use_tpu_summary_optimization: A boolean indicates whether to enable the
-      performance optimization for summaries in TPUs. In TPUs, writing
-      summaries with outside compilation inside train step is slow. If True,
-      it creates two `tf.function` with two XLA programs: one with summaries
-      and one without, and run the program with summaries (slow one) only if
-      necessary.
+    use_tf_while_loop: A boolean indicating whether to run the training loop
+      using a `tf.while_loop`. If `True`, `use_tf_function` must also be
+      `True`.
+    use_tf_function: A boolean indicating whether to apply `tf.function` to the
+      training loop. This will only affect the body of the loop (involving
+      `train_step`); `train_loop_begin` and `train_loop_end` will always be
+      run in eager mode.
+    use_tpu_summary_optimization: A boolean indicating whether to enable a
+      performance optimization for summaries in TPUs. Writing summaries
+      conditionally with outside compilation on TPUs can be extremely slow. If
+      `True`, this optimization creates two `tf.function`s with two XLA
+      programs (one with summary calls, and one without). The program with
+      summaries runs only for one step when summaries should be recorded.
   """
   use_tf_while_loop: bool = True
   use_tf_function: bool = True
@@ -46,39 +48,29 @@ class TrainerOverrides:
 class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractTrainer APIs."""

-  def __init__(self,
-               train_dataset,
-               use_tf_while_loop=True,
-               use_tf_function=True,
-               use_tpu_summary_optimization=False):
+  def __init__(self, train_dataset, options: StandardTrainerOptions = None):
     """Construct a `StandardTrainer` object.

     Args:
       train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
         DistributedDataset.
-      use_tf_while_loop: A boolean indicates whether to wrap the train step with
-        a `tf.while_loop`.
-      use_tf_function: A boolean indicates whether a `tf.function` will be used.
-        If False, training will run on pure eager mode.
-      use_tpu_summary_optimization: A boolean indicates whether to enable the
-        performance optimization for summaries in TPUs. In TPUs, writing
-        summaries with outside compilation inside train step is slow. If True,
-        it creates two `tf.function` with two XLA programs: one with summaries
-        and one without, and run the program with summaries (slow one) only if
-        necessary.
+      options: An `orbit.StandardTrainerOptions` instance.
     """
-    if use_tf_while_loop and not use_tf_function:
+    options = options or StandardTrainerOptions()
+    if options.use_tf_while_loop and not options.use_tf_function:
       raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
                        "is not supported")
-    if use_tpu_summary_optimization and not use_tf_while_loop:
+    if options.use_tpu_summary_optimization and not options.use_tf_while_loop:
       raise ValueError("`use_tpu_summary_optimization=True` and "
                        "`use_tf_while_loop=False` is not supported")
-    self._use_tf_while_loop = use_tf_while_loop
-    self._use_tf_function = use_tf_function
+    self._use_tf_while_loop = options.use_tf_while_loop
+    self._use_tf_function = options.use_tf_function
+    self._use_tpu_summary_optimization = options.use_tpu_summary_optimization
     self._train_dataset = train_dataset
     self._train_iter = None
     self._train_loop_fn = None
-    self._use_tpu_summary_optimization = use_tpu_summary_optimization

   def train(self,
             num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
@@ -168,12 +160,14 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):

 @dataclasses.dataclass(frozen=True)
-class EvaluatorOverrides:
-  """Advanced overrides for Orbit evaluators.
+class StandardEvaluatorOptions:
+  """Advanced options for the `orbit.StandardEvaluator`.

   Attributes:
-    use_tf_function: A boolean indicates whether a `tf.function` will be used.
-      If False, training will run on pure eager mode.
+    use_tf_function: A boolean indicating whether to apply `tf.function` to the
+      evaluation loop. This will only affect the body of the loop (involving
+      `eval_step`); `eval_begin` and `eval_end` will always be run in eager
+      mode.
   """
   use_tf_function: bool = True
@@ -181,16 +175,16 @@ class EvaluatorOverrides:
 class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractEvaluator APIs."""

-  def __init__(self, eval_dataset, use_tf_function=True):
+  def __init__(self, eval_dataset, options: StandardEvaluatorOptions = None):
     """Construct a `StandardEvaluator` object.

     Args:
       eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
         DistributedDataset.
-      use_tf_function: A boolean indicates whether a `tf.function` will be used.
-        If False, evaluation will run on pure eager mode.
+      options: An `orbit.StandardEvaluatorOptions` instance.
     """
-    self._eval_use_tf_function = use_tf_function
+    options = options or StandardEvaluatorOptions()
+    self._eval_use_tf_function = options.use_tf_function
     self._eval_dataset = eval_dataset
     self._eval_loop_fn = None
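To illustrate the constructor behavior introduced above (falling back to default options when `options` is omitted, and validating unsupported combinations up front), here is a small sketch; `DummyTrainer` is a hypothetical subclass written only for this example and is not part of the commit:

from orbit import standard_runner
import tensorflow as tf


class DummyTrainer(standard_runner.StandardTrainer):
  """Hypothetical subclass that only exercises the constructor."""

  def train_loop_begin(self):
    pass

  def train_step(self, iterator):
    next(iterator)  # consume one batch per step

  def train_loop_end(self):
    return None


dataset = tf.data.Dataset.range(4)

# Omitting `options` falls back to StandardTrainerOptions() defaults
# (use_tf_while_loop=True, use_tf_function=True).
DummyTrainer(train_dataset=dataset)

# An unsupported combination is rejected at construction time.
bad_options = standard_runner.StandardTrainerOptions(
    use_tf_while_loop=True, use_tf_function=False)
try:
  DummyTrainer(train_dataset=dataset, options=bad_options)
except ValueError as e:
  print(e)  # `use_tf_while_loop=True` and `use_tf_function=False` is not supported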
@@ -15,6 +15,7 @@
 """Tests for orbit.standard_runner."""

 from orbit import standard_runner
+from orbit import utils
 import tensorflow as tf
@@ -32,46 +33,49 @@ def dataset_fn(input_context=None):
   return dataset


-class TestRunner(standard_runner.StandardTrainer,
-                 standard_runner.StandardEvaluator):
-  """Implements the training and evaluation APIs for tests."""
+class TestTrainer(standard_runner.StandardTrainer):
+  """A StandardTrainer subclass for tests."""

-  def __init__(self):
+  def __init__(self, options=None):
     self.strategy = tf.distribute.get_strategy()
-    self.global_step = tf.Variable(
-        0,
-        trainable=False,
-        dtype=tf.int64,
-        name='global_step',
-        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
-    standard_runner.StandardTrainer.__init__(self, train_dataset=None)
-    standard_runner.StandardEvaluator.__init__(self, eval_dataset=None)
+    self.global_step = utils.create_global_step()
+    distribute = self.strategy.experimental_distribute_datasets_from_function
+    dataset = distribute(dataset_fn)
+    super().__init__(train_dataset=dataset, options=options)

   def train_loop_begin(self):
-    self.train_dataset = (
-        self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
-    )
+    self.global_step.assign(0)

   def train_step(self, iterator):

-    def _replicated_step(_):
+    def replica_step(_):
       self.global_step.assign_add(1)

-    self.strategy.run(_replicated_step, args=(next(iterator),))
+    self.strategy.run(replica_step, args=(next(iterator),))

   def train_loop_end(self):
     return self.global_step.numpy()

+
+class TestEvaluator(standard_runner.StandardEvaluator):
+  """A StandardEvaluator subclass for tests."""
+
+  def __init__(self, options=None):
+    self.strategy = tf.distribute.get_strategy()
+    self.global_step = utils.create_global_step()
+    distribute = self.strategy.experimental_distribute_datasets_from_function
+    dataset = distribute(dataset_fn)
+    super().__init__(eval_dataset=dataset, options=options)
+
   def eval_begin(self):
-    self.eval_dataset = self.strategy.experimental_distribute_datasets_from_function(
-        dataset_fn)
+    self.global_step.assign(0)

   def eval_step(self, iterator):

-    def _replicated_step(_):
+    def replica_step(_):
       self.global_step.assign_add(1)

-    self.strategy.run(_replicated_step, args=(next(iterator),))
+    self.strategy.run(replica_step, args=(next(iterator),))

   def eval_end(self):
     return self.global_step.numpy()
@@ -79,15 +83,19 @@ class TestRunner(standard_runner.StandardTrainer,

 class StandardRunnerTest(tf.test.TestCase):

-  def test_train(self):
-    test_runner = TestRunner()
-    self.assertEqual(
-        test_runner.train(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
+  def test_default_trainer(self):
+    trainer = TestTrainer()
+    self.assertEqual(trainer.train(tf.constant(10)), 10)

-  def test_eval(self):
-    test_runner = TestRunner()
-    self.assertEqual(
-        test_runner.evaluate(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
+  def test_trainer_with_tpu_summary_optimization(self):
+    options = standard_runner.StandardTrainerOptions(
+        use_tpu_summary_optimization=True)
+    trainer = TestTrainer(options)
+    self.assertEqual(trainer.train(tf.constant(10)), 10)
+
+  def test_default_evaluator(self):
+    evaluator = TestEvaluator()
+    self.assertEqual(evaluator.evaluate(tf.constant(10)), 10)


 if __name__ == '__main__':