Commit 61a61902 authored by A. Unique TensorFlower

Use `strategy.distribute_datasets_from_function` in the classifier trainer.

PiperOrigin-RevId: 307483983
parent ba772461
......@@ -32,7 +32,6 @@ import tensorflow as tf
from official.modeling import performance
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import callbacks as custom_callbacks
......@@ -316,7 +315,8 @@ def train_and_eval(
one_hot = label_smoothing and label_smoothing > 0
builders = _get_dataset_builders(params, strategy, one_hot)
datasets = [builder.build() if builder else None for builder in builders]
datasets = [builder.build(strategy)
if builder else None for builder in builders]
# Unpack datasets and builders based on train/val/test splits
train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
......
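For context, the call-site change above means the trainer now hands its `tf.distribute.Strategy` to each dataset builder rather than building an undistributed dataset. A minimal runnable sketch of that wiring, assuming TF 2.x eager mode; `ToyBuilder` is illustrative and is not the repo's `DatasetBuilder`:

import tensorflow as tf

class ToyBuilder:
  """Stand-in for DatasetBuilder: build() optionally distributes its dataset."""

  def _build(self, input_context=None):
    # Per-replica pipeline; the real builder shards and batches per replica.
    return tf.data.Dataset.range(32).batch(4)

  def build(self, strategy=None):
    if strategy:
      # The strategy calls _build once per input pipeline with an InputContext.
      return strategy.experimental_distribute_datasets_from_function(self._build)
    return self._build()

strategy = tf.distribute.MirroredStrategy()
builders = [ToyBuilder(), ToyBuilder()]
# Mirrors the updated list comprehension in train_and_eval.
datasets = [b.build(strategy) if b else None for b in builders]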
......@@ -10,6 +10,7 @@ train_dataset:
num_classes: 1000
num_examples: 1281167
batch_size: 32
use_per_replica_batch_size: True
dtype: 'float32'
validation_dataset:
name: 'imagenet2012'
......@@ -19,6 +20,7 @@ validation_dataset:
num_classes: 1000
num_examples: 50000
batch_size: 32
use_per_replica_batch_size: True
dtype: 'float32'
model:
model_params:
......
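With `use_per_replica_batch_size: True`, the `batch_size: 32` values above are read as per-replica sizes and scaled by the number of replicas. A small sketch of that arithmetic, matching the `global_batch_size`/`local_batch_size` properties in the builder below (device counts are arbitrary examples):

def global_batch_size(batch_size, num_devices, use_per_replica_batch_size):
  # When True, batch_size is per replica and is multiplied by the device count;
  # otherwise batch_size is already the global size.
  return batch_size * num_devices if use_per_replica_batch_size else batch_size

assert global_batch_size(32, 8, True) == 256   # e.g. 8 GPUs -> 256 images per step
assert global_batch_size(256, 8, False) == 256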
......@@ -84,7 +84,8 @@ class DatasetConfig(base_config.Config):
use_per_replica_batch_size: Whether to scale the batch size based on
available resources. If set to `True`, the dataset builder will return
batch_size multiplied by `num_devices`, the number of device replicas
(e.g., the number of GPUs or TPU cores).
(e.g., the number of GPUs or TPU cores). This setting should be `True` if
the strategy argument is passed to `build()` and `num_devices > 1`.
num_devices: The number of replica devices to use. This should be set by
`strategy.num_replicas_in_sync` when using a distribution strategy.
dtype: The desired dtype of the dataset. This will be set during
......@@ -194,6 +195,14 @@ class DatasetBuilder:
"""The global batch size across all replicas."""
return self.batch_size
@property
def local_batch_size(self):
"""The base unscaled batch size."""
if self.config.use_per_replica_batch_size:
return self.config.batch_size
else:
return self.config.batch_size // self.config.num_devices
@property
def num_steps(self) -> int:
"""The number of steps (batches) to exhaust this dataset."""
......@@ -264,19 +273,42 @@ class DatasetBuilder:
self.builder_info = tfds.builder(self.config.name).info
return self.builder_info
def build(self, input_context: tf.distribute.InputContext = None
) -> tf.data.Dataset:
def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it using an optional strategy.
Args:
strategy: a strategy that, if passed, will distribute the dataset
according to that strategy. If passed and `num_devices > 1`,
`use_per_replica_batch_size` must be set to `True`.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
if strategy:
if strategy.num_replicas_in_sync != self.config.num_devices:
logging.warn('Passed a strategy with %d devices, but expected '
'%d devices.',
strategy.num_replicas_in_sync,
self.config.num_devices)
dataset = strategy.experimental_distribute_datasets_from_function(
self._build)
else:
dataset = self._build()
return dataset
def _build(self, input_context: tf.distribute.InputContext = None
) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it.
Args:
input_context: An optional context provided by `tf.distribute` for
cross-replica training. This isn't necessary if using Keras
compile/fit.
cross-replica training.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
builders = {
'tfds': self.load_tfds,
'records': self.load_records,
......@@ -366,8 +398,8 @@ class DatasetBuilder:
Args:
dataset: A `tf.data.Dataset` that loads raw files.
input_context: An optional context provided by `tf.distribute` for
cross-replica training. This isn't necessary if using Keras
compile/fit.
cross-replica training. If set with more than one replica, this
function assumes `use_per_replica_batch_size=True`.
Returns:
A TensorFlow dataset outputting batched images and labels.
......@@ -387,8 +419,6 @@ class DatasetBuilder:
cycle_length=16,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(self.global_batch_size)
if self.config.cache:
dataset = dataset.cache()
......@@ -404,13 +434,25 @@ class DatasetBuilder:
dataset = dataset.map(preprocess,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)
# Note: we could do image normalization here, but we defer it to the model
# which can perform it much faster on a GPU/TPU
# TODO(dankondratyuk): if we fix prefetching, we can do it here
if input_context and self.config.num_devices > 1:
if not self.config.use_per_replica_batch_size:
raise ValueError(
'The builder does not support a global batch size with more than '
'one replica. Got {} replicas. Please set a '
'`per_replica_batch_size` and enable '
'`use_per_replica_batch_size=True`.'.format(
self.config.num_devices))
# The batch size of the dataset will be multiplied by the number of
# replicas automatically when strategy.distribute_datasets_from_function
# is called, so we use local batch size here.
dataset = dataset.batch(self.local_batch_size,
drop_remainder=self.is_training)
else:
dataset = dataset.batch(self.global_batch_size,
drop_remainder=self.is_training)
if self.is_training and self.config.deterministic_train is not None:
if self.is_training:
options = tf.data.Options()
options.experimental_deterministic = self.config.deterministic_train
options.experimental_slack = self.config.use_slack
......@@ -421,9 +463,7 @@ class DatasetBuilder:
dataset = dataset.with_options(options)
# Prefetch overlaps in-feed with training
# Note: autotune here is not recommended, as this can lead to memory leaks.
# Instead, use a constant prefetch size like the number of devices.
dataset = dataset.prefetch(self.config.num_devices)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
......
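To summarize the `build()`/`_build()` split above: `_build` is the per-input-pipeline dataset function, and `build(strategy)` passes it to `strategy.experimental_distribute_datasets_from_function`, which is why `_build` batches with the local (per-replica) batch size. A self-contained sketch of that pattern, assuming TF 2.2-era API names and arbitrary toy data:

import tensorflow as tf

PER_REPLICA_BATCH_SIZE = 32  # plays the role of batch_size with use_per_replica_batch_size: True

def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
  ds = tf.data.Dataset.range(1024)
  # Give each input pipeline a disjoint shard of the data.
  ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
  # Batch per replica; the strategy feeds one such batch to each replica, so the
  # effective global batch is PER_REPLICA_BATCH_SIZE * num_replicas_in_sync.
  ds = ds.batch(PER_REPLICA_BATCH_SIZE, drop_remainder=True)
  return ds.prefetch(tf.data.experimental.AUTOTUNE)

strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
for batch in dist_dataset:
  pass  # per-replica values, consumable by strategy.run or Keras compile/fit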