Commit 856622d3 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374244811
parent fb9f35c8
@@ -19,6 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
 interchangable and independent on model architectures and tasks.
 """
 import functools
+from typing import Union, Optional
 from absl import logging
 import gin
 import orbit
@@ -28,7 +29,6 @@ from official.core import base_task
 from official.core import config_definitions
 from official.modeling import optimization
 
 ExperimentConfig = config_definitions.ExperimentConfig
 TrainerConfig = config_definitions.TrainerConfig
@@ -143,6 +143,7 @@ class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         `tf.distribute.InputContext` instance.
       *args: Any positional arguments to pass through to `dataset_or_fn`.
       **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
+
     Returns:
       A distributed Dataset.
     """
@@ -173,14 +174,19 @@ class Trainer(_AsyncTrainer):
   """Implements the common trainer shared for TensorFlow models."""
 
   # pylint: disable=super-init-not-called
-  def __init__(self,
-               config: ExperimentConfig,
-               task: base_task.Task,
-               model: tf.keras.Model,
-               optimizer: tf.optimizers.Optimizer,
-               train: bool = True,
-               evaluate: bool = True,
-               checkpoint_exporter=None):
+  def __init__(
+      self,
+      config: ExperimentConfig,
+      task: base_task.Task,
+      model: tf.keras.Model,
+      optimizer: tf.optimizers.Optimizer,
+      train: bool = True,
+      evaluate: bool = True,
+      train_dataset: Optional[Union[tf.data.Dataset,
+                                    tf.distribute.DistributedDataset]] = None,
+      validation_dataset: Optional[Union[
+          tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
+      checkpoint_exporter=None):
     """Initialize common trainer for TensorFlow models.
 
     Args:
@@ -192,13 +198,22 @@ class Trainer(_AsyncTrainer):
         default to True.
       evaluate: bool, whether or not this trainer will be used for evaluation.
        default to True.
+      train_dataset: a dataset object created for training. With tf.distribute,
+        it needs to be a `DistributedDataset`.
+      validation_dataset: a dataset object created for evaluation. With
+        tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
+        create a dataset iterator for each eval round, so the dataset does not
+        need to repeat.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
-    self._validate_params(config)
+    self._validate_params(
+        config,
+        check_train_data=train_dataset is None,
+        check_validation_data=validation_dataset is None)
     self._config = config
     self._task = task
     self._model = model
@@ -239,7 +254,7 @@ class Trainer(_AsyncTrainer):
     self.init_async()
 
     if train:
-      train_dataset = self.distribute_dataset(
+      train_dataset = train_dataset or self.distribute_dataset(
          self.task.build_inputs, self.config.task.train_data)
       orbit.StandardTrainer.__init__(
          self,
@@ -250,16 +265,19 @@ class Trainer(_AsyncTrainer):
              use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
 
     if evaluate:
-      eval_dataset = self.distribute_dataset(
+      validation_dataset = validation_dataset or self.distribute_dataset(
          self.task.build_inputs, self.config.task.validation_data)
       orbit.StandardEvaluator.__init__(
          self,
-          eval_dataset,
+          validation_dataset,
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=config.trainer.eval_tf_function,
              use_tf_while_loop=config.trainer.eval_tf_while_loop))
 
-  def _validate_params(self, config):
+  def _validate_params(self,
+                       config,
+                       check_train_data=True,
+                       check_validation_data=True):
     r"""Validates if the configuration object passed to the Trainer.
 
     The experiment configuration should be structured as:
@@ -270,6 +288,8 @@ class Trainer(_AsyncTrainer):
 
     Args:
       config: a namedtuple, dataclass, ConfigDict, etc.
+      check_train_data: whether to check task.train_data field.
+      check_validation_data: whether to check task.validation_data field.
     """
     if not hasattr(config, "trainer"):
       raise AttributeError("The trainer requires the configuration contains an"
@@ -279,11 +299,11 @@ class Trainer(_AsyncTrainer):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task`.")
 
-    if not hasattr(config.task, "train_data"):
+    if check_train_data and not hasattr(config.task, "train_data"):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task.train_data`.")
 
-    if not hasattr(config.task, "validation_data"):
+    if check_validation_data and not hasattr(config.task, "validation_data"):
       raise AttributeError("The trainer requires the configuration contains an"
                            " attribute `task.validation_data`.")
@@ -406,8 +426,8 @@ class Trainer(_AsyncTrainer):
     for metric in self.validation_metrics + [self.validation_loss]:
       metric.reset_states()
     # Swaps weights to test on weights moving average.
-    if self.optimizer and isinstance(
-        self.optimizer, optimization.ExponentialMovingAverage):
+    if self.optimizer and isinstance(self.optimizer,
+                                     optimization.ExponentialMovingAverage):
       self.optimizer.swap_weights()
 
   def eval_step(self, iterator):
@@ -451,8 +471,8 @@ class Trainer(_AsyncTrainer):
     # Swaps back weights after testing when EMA is used.
     # This happens after best checkpoint export so that average weights used for
    # eval are exported instead of regular weights.
-    if self.optimizer and isinstance(
-        self.optimizer, optimization.ExponentialMovingAverage):
+    if self.optimizer and isinstance(self.optimizer,
+                                     optimization.ExponentialMovingAverage):
      self.optimizer.swap_weights()
     return logs
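For context, a minimal sketch (not part of this commit) of the weight-swapping pattern the two eval hunks above rely on. It assumes `optimizer` is an `optimization.ExponentialMovingAverage` wrapper whose `swap_weights()` exchanges the live model variables with their moving-average shadows; the trainer calls it once in eval-begin so evaluation and best-checkpoint export see the averaged weights, and once in eval-end to restore the training weights. The helper name and `eval_fn` argument are illustrative only.

from official.modeling import optimization

def evaluate_with_averaged_weights(optimizer, eval_fn):
  # Swap in the EMA shadow weights only when an EMA optimizer is in use.
  use_ema = isinstance(optimizer, optimization.ExponentialMovingAverage)
  if use_ema:
    optimizer.swap_weights()  # model now holds the averaged weights
  logs = eval_fn()            # run evaluation (and any checkpoint export) on EMA weights
  if use_ema:
    optimizer.swap_weights()  # restore the regular training weights
  return logs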
...
@@ -20,6 +20,7 @@ import sys
 from absl.testing import parameterized
 import numpy as np
+import orbit
 import portpicker
 import tensorflow as tf
@@ -111,15 +112,14 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
         aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
 
     train_dataset = self.distribute_dataset(dataset_fn)
-    trainer_lib.orbit.StandardTrainer.__init__(
-        self, train_dataset, options=trainer_lib.orbit.StandardTrainerOptions())
+    orbit.StandardTrainer.__init__(
+        self, train_dataset, options=orbit.StandardTrainerOptions())
 
-    eval_dataset = self.distribute_dataset(dataset_fn)
-    trainer_lib.orbit.StandardEvaluator.__init__(
+    validation_dataset = self.distribute_dataset(dataset_fn)
+    orbit.StandardEvaluator.__init__(
         self,
-        eval_dataset,
-        options=trainer_lib.orbit.StandardEvaluatorOptions(
-            use_tf_while_loop=True))
+        validation_dataset,
+        options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))
 
   def train_loop_begin(self):
     self.global_step.assign(0)
@@ -185,6 +185,30 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
       self.assertIn('training_loss', logs)
       self.assertIn('learning_rate', logs)
 
+  @combinations.generate(all_strategy_combinations())
+  def test_trainer_passing_datasets(self, distribution):
+    with distribution.scope():
+      task = mock_task.MockTask(self._config)
+      train_dataset = orbit.utils.make_distributed_dataset(
+          distribution, task.build_inputs, self._config.task.train_data)
+      validation_dataset = orbit.utils.make_distributed_dataset(
+          distribution, task.build_inputs, self._config.task.validation_data)
+      self._config.task.train_data = None
+      self._config.task.validation_data = None
+      trainer = trainer_lib.Trainer(
+          self._config,
+          task,
+          model=task.build_model(),
+          optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
                                          self._config.runtime),
+          train_dataset=train_dataset,
+          validation_dataset=validation_dataset)
+    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+    self.assertIn('training_loss', logs)
+    self.assertIn('learning_rate', logs)
+    logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+    self.assertIn('validation_loss', logs)
+
   def test_base_async_trainer(self):
     if TPU_TEST or GPU_TEST:
       self.skipTest('Aysnc training is not available on GPU/GPU.')
@@ -204,7 +228,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
   def test_async_trainer_train(self):
     if TPU_TEST or GPU_TEST:
-      self.skipTest('Aysnc training is not available on GPU/GPU.')
+      self.skipTest('Aysnc training is not available on GPU/TPU.')
 
     num_workers = 3
     num_ps = 2
     cluster_resolver = create_in_process_cluster(num_workers, num_ps)
...
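As a usage illustration (a minimal sketch, not from the commit): with the new constructor arguments, a caller can build the input pipelines up front and hand them to the trainer instead of letting it call `task.build_inputs` through `distribute_dataset`. It assumes `config` is an `ExperimentConfig`, `task` is a concrete `base_task.Task` built elsewhere, and that the trainer module is importable as `official.core.base_trainer`; under a `tf.distribute` strategy the datasets should already be distributed, e.g. via `orbit.utils.make_distributed_dataset` as in the new test.

import orbit
import tensorflow as tf

from official.core import base_trainer as trainer_lib

# Assumes `config` (ExperimentConfig) and `task` (base_task.Task) exist already.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = task.build_model()
  optimizer = task.create_optimizer(config.trainer.optimizer_config,
                                    config.runtime)
  # Pre-built distributed datasets; the trainer skips distribute_dataset when
  # these are passed, and _validate_params skips the corresponding data checks.
  train_ds = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, config.task.train_data)
  validation_ds = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, config.task.validation_data)
  trainer = trainer_lib.Trainer(
      config,
      task,
      model=model,
      optimizer=optimizer,
      train_dataset=train_ds,
      validation_dataset=validation_ds)

train_logs = trainer.train(tf.convert_to_tensor(10, dtype=tf.int32))
eval_logs = trainer.evaluate(tf.convert_to_tensor(10, dtype=tf.int32))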