"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "30101f73aaed90a0fb119161fe02fd95abeaa3fb"
Commit 61a61902 authored by A. Unique TensorFlower

Use `strategy.distribute_datasets_from_function` in the classifier trainer.

PiperOrigin-RevId: 307483983
parent ba772461
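For orientation, a minimal sketch (not part of this commit) of the tf.distribute pattern the change adopts: the strategy calls a dataset function once per input pipeline with a `tf.distribute.InputContext`, and each call batches with the per-replica size. The pipeline contents and batch size below are illustrative.

import tensorflow as tf

def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
  # Illustrative pipeline only; the real builder does decoding/preprocessing.
  per_replica_batch_size = 32  # assumed per-replica batch size
  ds = tf.data.Dataset.range(1024)
  # Shard by input pipeline so replicas do not read duplicate examples.
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.batch(per_replica_batch_size, drop_remainder=True)

strategy = tf.distribute.MirroredStrategy()
# TF 2.1/2.2 name this `experimental_distribute_datasets_from_function`;
# later releases expose it as `distribute_datasets_from_function`.
distributed_ds = strategy.experimental_distribute_datasets_from_function(dataset_fn)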
@@ -32,7 +32,6 @@ import tensorflow as tf
 from official.modeling import performance
 from official.modeling.hyperparams import params_dict
 from official.utils import hyperparams_flags
-from official.utils.logs import logger
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.vision.image_classification import callbacks as custom_callbacks
@@ -316,7 +315,8 @@ def train_and_eval(
   one_hot = label_smoothing and label_smoothing > 0

   builders = _get_dataset_builders(params, strategy, one_hot)
-  datasets = [builder.build() if builder else None for builder in builders]
+  datasets = [builder.build(strategy)
+              if builder else None for builder in builders]

   # Unpack datasets and builders based on train/val/test splits
   train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
...
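As a hedged sketch of how the trainer consumes these datasets (the model constructor and compile arguments are illustrative, not the trainer's actual code): the builders now return already-distributed datasets, so they can be passed straight to Keras `fit` on a model created under the same strategy scope.

# Illustrative only; the real trainer assembles the model, optimizer, and
# callbacks from its configuration.
with strategy.scope():
  model = build_model()  # hypothetical model constructor
  model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy')

train_dataset, validation_dataset = datasets
model.fit(train_dataset,
          epochs=train_epochs,                    # assumed from the config
          steps_per_epoch=train_builder.num_steps,
          validation_data=validation_dataset,
          validation_steps=validation_builder.num_steps)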
@@ -10,6 +10,7 @@ train_dataset:
   num_classes: 1000
   num_examples: 1281167
   batch_size: 32
+  use_per_replica_batch_size: True
   dtype: 'float32'
 validation_dataset:
   name: 'imagenet2012'
@@ -19,6 +20,7 @@ validation_dataset:
   num_classes: 1000
   num_examples: 50000
   batch_size: 32
+  use_per_replica_batch_size: True
   dtype: 'float32'
 model:
   model_params:
...
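With `use_per_replica_batch_size: True`, the `batch_size: 32` above is interpreted per replica, so the effective global batch grows with the number of devices. A small worked example, assuming 8 replicas (the device count is not part of this config):

per_replica_batch_size = 32   # batch_size from the config above
num_devices = 8               # assumed, e.g. strategy.num_replicas_in_sync

# use_per_replica_batch_size: True  -> the builder scales the batch up
global_batch_size = per_replica_batch_size * num_devices   # 256

# use_per_replica_batch_size: False -> batch_size is already the global size
local_batch_size = 256 // num_devices                       # 32 per replica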
@@ -84,7 +84,8 @@ class DatasetConfig(base_config.Config):
     use_per_replica_batch_size: Whether to scale the batch size based on
       available resources. If set to `True`, the dataset builder will return
       batch_size multiplied by `num_devices`, the number of device replicas
-      (e.g., the number of GPUs or TPU cores).
+      (e.g., the number of GPUs or TPU cores). This setting should be `True` if
+      the strategy argument is passed to `build()` and `num_devices > 1`.
     num_devices: The number of replica devices to use. This should be set by
       `strategy.num_replicas_in_sync` when using a distribution strategy.
     dtype: The desired dtype of the dataset. This will be set during
@@ -194,6 +195,14 @@ class DatasetBuilder:
     """The global batch size across all replicas."""
     return self.batch_size

+  @property
+  def local_batch_size(self):
+    """The base unscaled batch size."""
+    if self.config.use_per_replica_batch_size:
+      return self.config.batch_size
+    else:
+      return self.config.batch_size // self.config.num_devices
+
   @property
   def num_steps(self) -> int:
     """The number of steps (batches) to exhaust this dataset."""
@@ -264,19 +273,42 @@ class DatasetBuilder:
     self.builder_info = tfds.builder(self.config.name).info
     return self.builder_info

-  def build(self, input_context: tf.distribute.InputContext = None
-           ) -> tf.data.Dataset:
+  def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
+    """Construct a dataset end-to-end and return it using an optional strategy.
+
+    Args:
+      strategy: a strategy that, if passed, will distribute the dataset
+        according to that strategy. If passed and `num_devices > 1`,
+        `use_per_replica_batch_size` must be set to `True`.
+
+    Returns:
+      A TensorFlow dataset outputting batched images and labels.
+    """
+    if strategy:
+      if strategy.num_replicas_in_sync != self.config.num_devices:
+        logging.warn('Passed a strategy with %d devices, but expected'
+                     '%d devices.',
+                     strategy.num_replicas_in_sync,
+                     self.config.num_devices)
+      dataset = strategy.experimental_distribute_datasets_from_function(
+          self._build)
+    else:
+      dataset = self._build()
+
+    return dataset
+
+  def _build(self, input_context: tf.distribute.InputContext = None
+            ) -> tf.data.Dataset:
     """Construct a dataset end-to-end and return it.

     Args:
       input_context: An optional context provided by `tf.distribute` for
-        cross-replica training. This isn't necessary if using Keras
-        compile/fit.
+        cross-replica training.

     Returns:
       A TensorFlow dataset outputting batched images and labels.
     """
     builders = {
         'tfds': self.load_tfds,
         'records': self.load_records,
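A hypothetical caller-side sketch of the new signature (the constructor argument shown is illustrative; it assumes a `DatasetConfig` with `num_devices` already set):

builder = DatasetBuilder(config)
plain_ds = builder.build()           # ordinary tf.data.Dataset via _build()
dist_ds = builder.build(strategy)    # distributed dataset; the strategy calls
                                     # _build(input_context) once per pipeline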
@@ -366,8 +398,8 @@ class DatasetBuilder:
     Args:
       dataset: A `tf.data.Dataset` that loads raw files.
       input_context: An optional context provided by `tf.distribute` for
-        cross-replica training. This isn't necessary if using Keras
-        compile/fit.
+        cross-replica training. If set with more than one replica, this
+        function assumes `use_per_replica_batch_size=True`.

     Returns:
       A TensorFlow dataset outputting batched images and labels.
@@ -387,8 +419,6 @@ class DatasetBuilder:
         cycle_length=16,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)

-    dataset = dataset.prefetch(self.global_batch_size)
-
     if self.config.cache:
       dataset = dataset.cache()
@@ -404,13 +434,25 @@ class DatasetBuilder:
     dataset = dataset.map(preprocess,
                           num_parallel_calls=tf.data.experimental.AUTOTUNE)

-    dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)
-
-    # Note: we could do image normalization here, but we defer it to the model
-    # which can perform it much faster on a GPU/TPU
-    # TODO(dankondratyuk): if we fix prefetching, we can do it here
+    if input_context and self.config.num_devices > 1:
+      if not self.config.use_per_replica_batch_size:
+        raise ValueError(
+            'The builder does not support a global batch size with more than '
+            'one replica. Got {} replicas. Please set a '
+            '`per_replica_batch_size` and enable '
+            '`use_per_replica_batch_size=True`.'.format(
+                self.config.num_devices))
+
+      # The batch size of the dataset will be multiplied by the number of
+      # replicas automatically when strategy.distribute_datasets_from_function
+      # is called, so we use local batch size here.
+      dataset = dataset.batch(self.local_batch_size,
+                              drop_remainder=self.is_training)
+    else:
+      dataset = dataset.batch(self.global_batch_size,
+                              drop_remainder=self.is_training)

-    if self.is_training and self.config.deterministic_train is not None:
+    if self.is_training:
       options = tf.data.Options()
       options.experimental_deterministic = self.config.deterministic_train
       options.experimental_slack = self.config.use_slack
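The branch above relies on the config's `use_per_replica_batch_size` to know that `local_batch_size` is already per replica. For comparison, a sketch of the equivalent arithmetic done through the tf.distribute API directly (not how this builder computes it):

def batch_for_replica(dataset: tf.data.Dataset,
                      global_batch_size: int,
                      input_context: tf.distribute.InputContext = None):
  """Batches with each replica's share of a global batch size (sketch)."""
  if input_context:
    # Divides the global batch across replicas; raises if not evenly divisible.
    per_replica = input_context.get_per_replica_batch_size(global_batch_size)
    return dataset.batch(per_replica, drop_remainder=True)
  return dataset.batch(global_batch_size, drop_remainder=True)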
@@ -421,9 +463,7 @@ class DatasetBuilder:
       dataset = dataset.with_options(options)

     # Prefetch overlaps in-feed with training
-    # Note: autotune here is not recommended, as this can lead to memory leaks.
-    # Instead, use a constant prefetch size like the the number of devices.
-    dataset = dataset.prefetch(self.config.num_devices)
+    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

     return dataset
...