Commit 856622d3 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374244811
parent fb9f35c8
......@@ -19,6 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
interchangeable and independent of model architectures and tasks.
"""
import functools
from typing import Union, Optional
from absl import logging
import gin
import orbit
......@@ -28,7 +29,6 @@ from official.core import base_task
from official.core import config_definitions
from official.modeling import optimization
ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig
......@@ -143,6 +143,7 @@ class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
`tf.distribute.InputContext` instance.
*args: Any positional arguments to pass through to `dataset_or_fn`.
**kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
Returns:
A distributed Dataset.
"""
......@@ -173,14 +174,19 @@ class Trainer(_AsyncTrainer):
"""Implements the common trainer shared for TensorFlow models."""
# pylint: disable=super-init-not-called
def __init__(self,
config: ExperimentConfig,
task: base_task.Task,
model: tf.keras.Model,
optimizer: tf.optimizers.Optimizer,
train: bool = True,
evaluate: bool = True,
checkpoint_exporter=None):
def __init__(
self,
config: ExperimentConfig,
task: base_task.Task,
model: tf.keras.Model,
optimizer: tf.optimizers.Optimizer,
train: bool = True,
evaluate: bool = True,
train_dataset: Optional[Union[tf.data.Dataset,
tf.distribute.DistributedDataset]] = None,
validation_dataset: Optional[Union[
tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
checkpoint_exporter=None):
"""Initialize common trainer for TensorFlow models.
Args:
......@@ -192,13 +198,22 @@ class Trainer(_AsyncTrainer):
Defaults to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
Defaults to True.
train_dataset: a dataset object created for training. With tf.distribute,
it needs to be a `DistributedDataset`.
validation_dataset: a dataset object created for evaluation. With
tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
create a dataset iterator for each eval round, so the dataset does not
need to repeat.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._validate_params(config)
self._validate_params(
config,
check_train_data=train_dataset is None,
check_validation_data=validation_dataset is None)
self._config = config
self._task = task
self._model = model
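For context, a minimal sketch of the new usage enabled by this change: constructing the trainer with pre-built distributed datasets instead of letting it build them from `config.task.train_data` / `config.task.validation_data`. It assumes a `task`/`config` pair like the ones in the new test further down this diff:

```python
import orbit
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # or MirroredStrategy, TPUStrategy, ...
with strategy.scope():
  # Build the distributed datasets up front, as the new test does.
  train_ds = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, config.task.train_data)
  val_ds = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, config.task.validation_data)
  trainer = Trainer(
      config,
      task,
      model=task.build_model(),
      optimizer=task.create_optimizer(config.trainer.optimizer_config,
                                      config.runtime),
      train_dataset=train_ds,
      validation_dataset=val_ds)
```

With datasets passed in, `_validate_params` skips the `task.train_data` / `task.validation_data` checks, so those config fields may be left unset.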
......@@ -239,7 +254,7 @@ class Trainer(_AsyncTrainer):
self.init_async()
if train:
train_dataset = self.distribute_dataset(
train_dataset = train_dataset or self.distribute_dataset(
self.task.build_inputs, self.config.task.train_data)
orbit.StandardTrainer.__init__(
self,
......@@ -250,16 +265,19 @@ class Trainer(_AsyncTrainer):
use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
if evaluate:
eval_dataset = self.distribute_dataset(
validation_dataset = validation_dataset or self.distribute_dataset(
self.task.build_inputs, self.config.task.validation_data)
orbit.StandardEvaluator.__init__(
self,
eval_dataset,
validation_dataset,
options=orbit.StandardEvaluatorOptions(
use_tf_function=config.trainer.eval_tf_function,
use_tf_while_loop=config.trainer.eval_tf_while_loop))
def _validate_params(self, config):
def _validate_params(self,
config,
check_train_data=True,
check_validation_data=True):
r"""Validates if the configuration object passed to the Trainer.
The experiment configuration should be structured as:
......@@ -270,6 +288,8 @@ class Trainer(_AsyncTrainer):
Args:
config: a namedtuple, dataclass, ConfigDict, etc.
check_train_data: whether to check task.train_data field.
check_validation_data: whether to check task.validation_data field.
"""
if not hasattr(config, "trainer"):
raise AttributeError("The trainer requires the configuration contains an"
......@@ -279,11 +299,11 @@ class Trainer(_AsyncTrainer):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task`.")
if not hasattr(config.task, "train_data"):
if check_train_data and not hasattr(config.task, "train_data"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task.train_data`.")
if not hasattr(config.task, "validation_data"):
if check_validation_data and not hasattr(config.task, "validation_data"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task.validation_data`.")
......@@ -406,8 +426,8 @@ class Trainer(_AsyncTrainer):
for metric in self.validation_metrics + [self.validation_loss]:
metric.reset_states()
# Swaps weights to evaluate using the moving average of the weights.
if self.optimizer and isinstance(
self.optimizer, optimization.ExponentialMovingAverage):
if self.optimizer and isinstance(self.optimizer,
optimization.ExponentialMovingAverage):
self.optimizer.swap_weights()
def eval_step(self, iterator):
......@@ -451,8 +471,8 @@ class Trainer(_AsyncTrainer):
# Swaps back weights after testing when EMA is used.
# This happens after best checkpoint export so that average weights used for
# eval are exported instead of regular weights.
if self.optimizer and isinstance(
self.optimizer, optimization.ExponentialMovingAverage):
if self.optimizer and isinstance(self.optimizer,
optimization.ExponentialMovingAverage):
self.optimizer.swap_weights()
return logs
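The two `swap_weights()` call sites above form a symmetric pair around each evaluation round; a rough pseudocode sketch of that flow (illustrative, not the literal Trainer code):

```python
def eval_begin(self):
  # Reset metrics, then evaluate against the averaged weights.
  if isinstance(self.optimizer, optimization.ExponentialMovingAverage):
    self.optimizer.swap_weights()  # moving averages -> model variables

def eval_end(self, aggregated_logs):
  logs = ...  # aggregate metrics; best-checkpoint export runs here, so the
              # exported "best" checkpoint holds the averaged weights
  if isinstance(self.optimizer, optimization.ExponentialMovingAverage):
    self.optimizer.swap_weights()  # restore the regular training weights
  return logs
```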
......
......@@ -20,6 +20,7 @@ import sys
from absl.testing import parameterized
import numpy as np
import orbit
import portpicker
import tensorflow as tf
......@@ -111,15 +112,14 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
train_dataset = self.distribute_dataset(dataset_fn)
trainer_lib.orbit.StandardTrainer.__init__(
self, train_dataset, options=trainer_lib.orbit.StandardTrainerOptions())
orbit.StandardTrainer.__init__(
self, train_dataset, options=orbit.StandardTrainerOptions())
eval_dataset = self.distribute_dataset(dataset_fn)
trainer_lib.orbit.StandardEvaluator.__init__(
validation_dataset = self.distribute_dataset(dataset_fn)
orbit.StandardEvaluator.__init__(
self,
eval_dataset,
options=trainer_lib.orbit.StandardEvaluatorOptions(
use_tf_while_loop=True))
validation_dataset,
options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))
def train_loop_begin(self):
self.global_step.assign(0)
......@@ -185,6 +185,30 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_passing_datasets(self, distribution):
with distribution.scope():
task = mock_task.MockTask(self._config)
train_dataset = orbit.utils.make_distributed_dataset(
distribution, task.build_inputs, self._config.task.train_data)
validation_dataset = orbit.utils.make_distributed_dataset(
distribution, task.build_inputs, self._config.task.validation_data)
self._config.task.train_data = None
self._config.task.validation_data = None
trainer = trainer_lib.Trainer(
self._config,
task,
model=task.build_model(),
optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
self._config.runtime),
train_dataset=train_dataset,
validation_dataset=validation_dataset)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
def test_base_async_trainer(self):
if TPU_TEST or GPU_TEST:
self.skipTest('Async training is not available on GPU/TPU.')
......@@ -204,7 +228,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def test_async_trainer_train(self):
if TPU_TEST or GPU_TEST:
self.skipTest('Aysnc training is not available on GPU/GPU.')
self.skipTest('Async training is not available on GPU/TPU.')
num_workers = 3
num_ps = 2
cluster_resolver = create_in_process_cluster(num_workers, num_ps)
......