Commit ebc3edc6 authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 310965031
parent 9b8544e0
@@ -13,7 +13,7 @@ train_dataset:
   image_size: 224
   num_classes: 1000
   num_examples: 1281167
-  batch_size: 64
+  batch_size: 256
   use_per_replica_batch_size: True
   dtype: 'float16'
   mean_subtract: True
@@ -26,7 +26,7 @@ validation_dataset:
   image_size: 224
   num_classes: 1000
   num_examples: 50000
-  batch_size: 64
+  batch_size: 256
   use_per_replica_batch_size: True
   dtype: 'float16'
   mean_subtract: True
...
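Context for the change above: with `use_per_replica_batch_size: True`, the `batch_size` value is applied per replica, so raising it from 64 to 256 scales the effective global batch with the accelerator count. A minimal sketch of the arithmetic (the replica count below is hypothetical, not from this commit):

```python
# Hypothetical illustration of per-replica vs. global batch size.
per_replica_batch_size = 256  # new value in this config
num_replicas = 8              # assumed number of synchronized replicas

# With use_per_replica_batch_size=True, each replica batches independently,
# so the optimizer sees the product as the global batch per step.
global_batch_size = per_replica_batch_size * num_replicas
print(global_batch_size)  # 2048
```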
@@ -15,8 +15,8 @@ train_dataset:
   num_examples: 1281167
   batch_size: 128
   use_per_replica_batch_size: True
-  mean_subtract: True
-  standardize: True
+  mean_subtract: False
+  standardize: False
   dtype: 'bfloat16'
 validation_dataset:
   name: 'imagenet2012'
@@ -29,13 +29,13 @@ validation_dataset:
   num_examples: 50000
   batch_size: 128
   use_per_replica_batch_size: True
-  mean_subtract: True
-  standardize: True
+  mean_subtract: False
+  standardize: False
   dtype: 'bfloat16'
 model:
   name: 'resnet'
   model_params:
-    rescale_inputs: False
+    rescale_inputs: True
 optimizer:
   name: 'momentum'
   momentum: 0.9
...
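The second config stops normalizing images in the input pipeline (`mean_subtract: False`, `standardize: False`) and instead lets the model rescale its own inputs (`rescale_inputs: True`). A hedged sketch of the general idea, not this repo's ResNet implementation; the layer and constants below are illustrative only:

```python
import tensorflow as tf

# Common ImageNet channel statistics, used here purely for illustration.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)

def build_model_with_input_rescaling(num_classes: int = 1000) -> tf.keras.Model:
  """Folds normalization into the model so the data pipeline can skip it."""
  inputs = tf.keras.Input(shape=(224, 224, 3))
  x = tf.keras.layers.Lambda(
      lambda img: (img - tf.constant(MEAN_RGB)) / tf.constant(STDDEV_RGB))(inputs)
  # A real backbone would follow; one pooling + dense head keeps this short.
  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  outputs = tf.keras.layers.Dense(num_classes)(x)
  return tf.keras.Model(inputs, outputs)
```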
@@ -284,7 +284,6 @@ class DatasetBuilder:
           '%d devices.',
           strategy.num_replicas_in_sync,
           self.config.num_devices)
       dataset = strategy.experimental_distribute_datasets_from_function(
           self._build)
     else:
@@ -314,8 +313,9 @@ class DatasetBuilder:
     if builder is None:
       raise ValueError('Unknown builder type {}'.format(self.config.builder))
+    self.input_context = input_context
     dataset = builder()
-    dataset = self.pipeline(dataset, input_context)
+    dataset = self.pipeline(dataset)
     return dataset
@@ -336,8 +336,9 @@ class DatasetBuilder:
       decoders['image'] = tfds.decode.SkipDecoding()
     read_config = tfds.ReadConfig(
-        interleave_cycle_length=64,
-        interleave_block_length=1)
+        interleave_cycle_length=10,
+        interleave_block_length=1,
+        input_context=self.input_context)
     dataset = builder.as_dataset(
         split=self.config.split,
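Forwarding the stored context through `tfds.ReadConfig` lets TFDS shard the source files itself, which is why the explicit `dataset.shard(...)` later in `pipeline` is skipped for the `'tfds'` builder. A hedged standalone sketch (the two-pipeline context is assumed, and `imagenet2012` must already be prepared locally):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Pretend this process is pipeline 0 of 2, as on a two-worker job.
ctx = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=2)

read_config = tfds.ReadConfig(
    interleave_cycle_length=10,
    interleave_block_length=1,
    input_context=ctx)  # TFDS divides the record files between pipelines

# This pipeline reads only its own half of the imagenet2012 train files.
ds = tfds.load('imagenet2012', split='train', read_config=read_config)
```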
@@ -351,19 +352,15 @@ class DatasetBuilder:
   def load_records(self) -> tf.data.Dataset:
     """Return a dataset loading files with TFRecords."""
     logging.info('Using TFRecords to load data.')

     if self.config.filenames is None:
       if self.config.data_dir is None:
         raise ValueError('Dataset must specify a path for the data files.')
       file_pattern = os.path.join(self.config.data_dir,
                                   '{}*'.format(self.config.split))
-      dataset = tf.data.Dataset.list_files(file_pattern, shuffle=True)
+      dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
     else:
       dataset = tf.data.Dataset.from_tensor_slices(self.config.filenames)

-    if self.is_training:
-      # Shuffle the input files.
-      dataset.shuffle(buffer_size=self.config.file_shuffle_buffer_size)
     return dataset
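Listing files with `shuffle=False` matters once sharding happens inside `pipeline`: every input pipeline must see the files in the same order, otherwise `shard()` would produce overlapping, incomplete shards. A tiny sketch of the invariant on synthetic filenames (not this repo's data):

```python
import tensorflow as tf

# Stand-in for tf.data.Dataset.list_files(pattern, shuffle=False).
files = tf.data.Dataset.from_tensor_slices(
    ['train-%05d' % i for i in range(8)])

# With a deterministic file order, two pipelines partition the files
# cleanly: together they cover all eight files exactly once.
shard0 = [f.numpy() for f in files.shard(2, 0)]
shard1 = [f.numpy() for f in files.shard(2, 1)]
assert set(shard0 + shard1) == {b'train-%05d' % i for i in range(8)}
```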
@@ -383,34 +380,37 @@ class DatasetBuilder:
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
     return dataset

-  def pipeline(self,
-               dataset: tf.data.Dataset,
-               input_context: tf.distribute.InputContext = None
-              ) -> tf.data.Dataset:
+  def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
     """Build a pipeline fetching, shuffling, and preprocessing the dataset.

     Args:
       dataset: A `tf.data.Dataset` that loads raw files.
-      input_context: An optional context provided by `tf.distribute` for
-        cross-replica training. If set with more than one replica, this
-        function assumes `use_per_replica_batch_size=True`.

     Returns:
       A TensorFlow dataset outputting batched images and labels.
     """
-    if input_context and input_context.num_input_pipelines > 1:
-      dataset = dataset.shard(input_context.num_input_pipelines,
-                              input_context.input_pipeline_id)
+    if (self.config.builder != 'tfds' and self.input_context
+        and self.input_context.num_input_pipelines > 1):
+      dataset = dataset.shard(self.input_context.num_input_pipelines,
+                              self.input_context.input_pipeline_id)
+      logging.info('Sharding the dataset: input_pipeline_id=%d '
+                   'num_input_pipelines=%d',
+                   self.input_context.input_pipeline_id,
+                   self.input_context.num_input_pipelines)
+
+    if self.is_training and self.config.builder == 'records':
+      # Shuffle the input files.
+      dataset = dataset.shuffle(
+          buffer_size=self.config.file_shuffle_buffer_size)

     if self.is_training and not self.config.cache:
       dataset = dataset.repeat()

     if self.config.builder == 'records':
       # Read the data from disk in parallel
-      buffer_size = 8 * 1024 * 1024  # Use 8 MiB per file
       dataset = dataset.interleave(
-          lambda name: tf.data.TFRecordDataset(name, buffer_size=buffer_size),
-          cycle_length=16,
+          tf.data.TFRecordDataset,
+          cycle_length=10,
+          block_length=1,
           num_parallel_calls=tf.data.experimental.AUTOTUNE)

     if self.config.cache:
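With the explicit 8 MiB reader buffer dropped, the interleave step passes `tf.data.TFRecordDataset` straight in as the map function and lets tf.data choose reader defaults, while still opening `cycle_length` files concurrently. A minimal sketch of the pattern (the shard paths are hypothetical):

```python
import tensorflow as tf

# Hypothetical shard files; the real pipeline gets these from load_records.
filenames = tf.data.Dataset.from_tensor_slices(
    ['/tmp/imagenet/train-%05d-of-00008' % i for i in range(8)])

dataset = filenames.interleave(
    tf.data.TFRecordDataset,  # the dataset constructor is the map function
    cycle_length=10,          # read up to 10 files concurrently
    block_length=1,           # emit one record per file per round
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
```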
@@ -428,7 +428,7 @@ class DatasetBuilder:
     dataset = dataset.map(preprocess,
                           num_parallel_calls=tf.data.experimental.AUTOTUNE)

-    if input_context and self.config.num_devices > 1:
+    if self.input_context and self.config.num_devices > 1:
       if not self.config.use_per_replica_batch_size:
         raise ValueError(
             'The builder does not support a global batch size with more than '
...
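The last hunk keeps the guard that multi-device runs must use per-replica batch sizes; when a global batch size is all you have, `tf.distribute.InputContext` can do the conversion. A small sketch (the replica count is assumed):

```python
import tensorflow as tf

ctx = tf.distribute.InputContext(
    num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=8)

# Splits a global batch evenly across replicas; raises ValueError if the
# global size is not divisible by the number of replicas in sync.
per_replica = ctx.get_per_replica_batch_size(2048)
print(per_replica)  # 2048 / 8 = 256
```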