Commit ebc3edc6 authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 310965031
parent 9b8544e0
@@ -13,7 +13,7 @@ train_dataset:
   image_size: 224
   num_classes: 1000
   num_examples: 1281167
-  batch_size: 64
+  batch_size: 256
   use_per_replica_batch_size: True
   dtype: 'float16'
   mean_subtract: True
@@ -26,7 +26,7 @@ validation_dataset:
   image_size: 224
   num_classes: 1000
   num_examples: 50000
-  batch_size: 64
+  batch_size: 256
   use_per_replica_batch_size: True
   dtype: 'float16'
   mean_subtract: True
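For context: with `use_per_replica_batch_size: True`, the `batch_size` above is applied per replica rather than globally, so raising it from 64 to 256 raises each replica's batch, and the effective global batch scales with the replica count. A minimal sketch of the arithmetic (the replica count is an assumed example, not part of this commit):

```python
# Illustration only: how a per-replica batch size becomes a global one.
per_replica_batch_size = 256  # `batch_size` from the config above
num_replicas = 8              # assumed, e.g. 8 GPUs under MirroredStrategy

# Each replica pulls its own batch, so one training step consumes:
global_batch_size = per_replica_batch_size * num_replicas  # 2048 images
print('global batch size:', global_batch_size)
```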
@@ -15,8 +15,8 @@ train_dataset:
   num_examples: 1281167
   batch_size: 128
   use_per_replica_batch_size: True
-  mean_subtract: True
-  standardize: True
+  mean_subtract: False
+  standardize: False
   dtype: 'bfloat16'
 validation_dataset:
   name: 'imagenet2012'
@@ -29,13 +29,13 @@ validation_dataset:
   num_examples: 50000
   batch_size: 128
   use_per_replica_batch_size: True
-  mean_subtract: True
-  standardize: True
+  mean_subtract: False
+  standardize: False
   dtype: 'bfloat16'
 model:
   name: 'resnet'
   model_params:
-    rescale_inputs: False
+    rescale_inputs: True
 optimizer:
   name: 'momentum'
   momentum: 0.9
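This config stops normalizing images in the input pipeline (`mean_subtract` and `standardize` flip to False) and instead asks the model to rescale its own inputs (`rescale_inputs: True`). A rough sketch of that pattern, using an assumed Keras preprocessing layer rather than this repo's actual ResNet code:

```python
import tensorflow as tf

# Sketch (assumed layer choice): fold input rescaling into the model so
# the tf.data pipeline can emit raw pixel values.
def build_model(num_classes: int = 1000) -> tf.keras.Model:
  inputs = tf.keras.Input(shape=(224, 224, 3))
  # Normalize [0, 255] pixels inside the graph instead of in the dataset.
  x = tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)(inputs)
  x = tf.keras.layers.Conv2D(64, 7, strides=2, padding='same',
                             activation='relu')(x)
  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  outputs = tf.keras.layers.Dense(num_classes)(x)
  return tf.keras.Model(inputs, outputs)
```

Keeping the rescale in-graph means the same normalization is applied at serving time without re-creating the input pipeline.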
@@ -284,7 +284,6 @@ class DatasetBuilder:
                      '%d devices.',
                      strategy.num_replicas_in_sync,
                      self.config.num_devices)
-
       dataset = strategy.experimental_distribute_datasets_from_function(
           self._build)
     else:
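For reference, `experimental_distribute_datasets_from_function` invokes a dataset factory once per input pipeline and passes it a `tf.distribute.InputContext`; that context is what this commit starts threading through the builder. A self-contained usage sketch with toy data (not this repo's `_build`):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(ctx: tf.distribute.InputContext) -> tf.data.Dataset:
  # Runs once per input pipeline; ctx identifies this worker's shard.
  batch_size = ctx.get_per_replica_batch_size(global_batch_size=256)
  ds = tf.data.Dataset.range(1024)
  ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  return ds.batch(batch_size)

dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn)
```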
@@ -314,8 +313,9 @@ class DatasetBuilder:
     if builder is None:
       raise ValueError('Unknown builder type {}'.format(self.config.builder))

+    self.input_context = input_context
     dataset = builder()
-    dataset = self.pipeline(dataset, input_context)
+    dataset = self.pipeline(dataset)

     return dataset
@@ -336,8 +336,9 @@ class DatasetBuilder:
       decoders['image'] = tfds.decode.SkipDecoding()

     read_config = tfds.ReadConfig(
-        interleave_cycle_length=64,
-        interleave_block_length=1)
+        interleave_cycle_length=10,
+        interleave_block_length=1,
+        input_context=self.input_context)

     dataset = builder.as_dataset(
         split=self.config.split,
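Passing the stored `InputContext` into `tfds.ReadConfig` lets TFDS shard reads at the file level, which is why the explicit `Dataset.shard` later in `pipeline` now skips the 'tfds' builder. A minimal sketch, assuming `imagenet2012` has already been downloaded and prepared locally:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Sketch: hand the distribution context to TFDS so it shards file reads
# per input pipeline. A None context simply disables sharding.
def load_tfds_split(input_context=None) -> tf.data.Dataset:
  read_config = tfds.ReadConfig(
      interleave_cycle_length=10,
      interleave_block_length=1,
      input_context=input_context)
  builder = tfds.builder('imagenet2012')  # assumes prepared data
  return builder.as_dataset(split='train', read_config=read_config)
```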
@@ -351,19 +352,15 @@ class DatasetBuilder:
   def load_records(self) -> tf.data.Dataset:
     """Return a dataset loading files with TFRecords."""
     logging.info('Using TFRecords to load data.')

     if self.config.filenames is None:
       if self.config.data_dir is None:
         raise ValueError('Dataset must specify a path for the data files.')

       file_pattern = os.path.join(self.config.data_dir,
                                   '{}*'.format(self.config.split))
-      dataset = tf.data.Dataset.list_files(file_pattern, shuffle=True)
+      dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
     else:
       dataset = tf.data.Dataset.from_tensor_slices(self.config.filenames)

-    if self.is_training:
-      # Shuffle the input files.
-      dataset.shuffle(buffer_size=self.config.file_shuffle_buffer_size)

     return dataset
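Listing files with `shuffle=False` makes the file order deterministic across workers, which matters once sharding happens in `pipeline()`: every input pipeline must enumerate the same ordering for `Dataset.shard` to hand out disjoint files. A toy illustration (the file pattern is hypothetical):

```python
import tensorflow as tf

# Hypothetical pattern: with shuffle=True each worker could list the files
# in a different order, so shard() might give workers overlapping shards.
# A fixed order keeps the shards disjoint; shuffling happens later.
files = tf.data.Dataset.list_files('/data/imagenet/train*', shuffle=False)
worker0 = files.shard(num_shards=2, index=0)
worker1 = files.shard(num_shards=2, index=1)
```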
@@ -383,34 +380,37 @@ class DatasetBuilder:
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
     return dataset

-  def pipeline(self,
-               dataset: tf.data.Dataset,
-               input_context: tf.distribute.InputContext = None
-              ) -> tf.data.Dataset:
+  def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
     """Build a pipeline fetching, shuffling, and preprocessing the dataset.

     Args:
       dataset: A `tf.data.Dataset` that loads raw files.
-      input_context: An optional context provided by `tf.distribute` for
-        cross-replica training. If set with more than one replica, this
-        function assumes `use_per_replica_batch_size=True`.

     Returns:
       A TensorFlow dataset outputting batched images and labels.
     """
-    if input_context and input_context.num_input_pipelines > 1:
-      dataset = dataset.shard(input_context.num_input_pipelines,
-                              input_context.input_pipeline_id)
+    if (self.config.builder != 'tfds' and self.input_context
+        and self.input_context.num_input_pipelines > 1):
+      dataset = dataset.shard(self.input_context.num_input_pipelines,
+                              self.input_context.input_pipeline_id)
+      logging.info('Sharding the dataset: input_pipeline_id=%d '
+                   'num_input_pipelines=%d',
+                   self.input_context.input_pipeline_id,
+                   self.input_context.num_input_pipelines)

+    if self.is_training and self.config.builder == 'records':
+      # Shuffle the input files.
+      dataset = dataset.shuffle(
+          buffer_size=self.config.file_shuffle_buffer_size)

     if self.is_training and not self.config.cache:
       dataset = dataset.repeat()

     if self.config.builder == 'records':
       # Read the data from disk in parallel.
-      buffer_size = 8 * 1024 * 1024  # Use 8 MiB per file
       dataset = dataset.interleave(
-          lambda name: tf.data.TFRecordDataset(name, buffer_size=buffer_size),
-          cycle_length=16,
+          tf.data.TFRecordDataset,
+          cycle_length=10,
           block_length=1,
           num_parallel_calls=tf.data.experimental.AUTOTUNE)

     if self.config.cache:
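The interleave change above drops the lambda with its explicit 8 MiB `buffer_size` and maps `tf.data.TFRecordDataset` directly over the file names, relying on its default read buffer and a smaller cycle length. A stand-alone sketch of the resulting read pattern (file pattern hypothetical):

```python
import tensorflow as tf

# Sketch: read TFRecord files in parallel. The dataset constructor itself
# is the map function; 10 files are open at a time, block_length=1 takes
# one record from each in turn, and AUTOTUNE picks the parallelism.
files = tf.data.Dataset.list_files('/data/imagenet/train*', shuffle=False)
records = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=10,
    block_length=1,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
```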
@@ -428,7 +428,7 @@ class DatasetBuilder:
     dataset = dataset.map(preprocess,
                           num_parallel_calls=tf.data.experimental.AUTOTUNE)

-    if input_context and self.config.num_devices > 1:
+    if self.input_context and self.config.num_devices > 1:
       if not self.config.use_per_replica_batch_size:
         raise ValueError(
             'The builder does not support a global batch size with more than '