"vscode:/vscode.git/clone" did not exist on "f6af67eb112793a5f9192adb544fa783c0ccf8f8"
Commit eebea3f8 authored by Haoyu Zhang

Optimize data input pipeline


Co-authored-by: Jiri Simsa <jsimsa@google.com>
parent 8e7051a8
@@ -200,12 +200,13 @@ def input_fn(is_training,
     dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
 
   # Convert to individual records.
-  # cycle_length = 10 means 10 files will be read and deserialized in parallel.
-  # This number is low enough to not cause too much contention on small systems
-  # but high enough to provide the benefits of parallelization. You may want
-  # to increase this number if you have a large number of CPU cores.
-  dataset = dataset.apply(tf.data.experimental.parallel_interleave(
-      tf.data.TFRecordDataset, cycle_length=10))
+  # cycle_length = 10 means that up to 10 files will be read and deserialized in
+  # parallel. You may want to increase this number if you have a large number of
+  # CPU cores.
+  dataset = dataset.interleave(
+      tf.data.TFRecordDataset,
+      cycle_length=10,
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
   return resnet_run_loop.process_record_dataset(
       dataset=dataset,
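This hunk replaces the experimental `parallel_interleave` transform with the core `Dataset.interleave` API, letting `AUTOTUNE` decide the actual degree of read parallelism. A minimal standalone sketch of the new pattern, assuming a TensorFlow version where `interleave` accepts `num_parallel_calls` (1.13 or later); the file glob is hypothetical:

```python
import tensorflow as tf

# Hypothetical file pattern; any set of TFRecord shards works.
filenames = tf.data.Dataset.list_files("/data/train-*.tfrecord")

# Open up to 10 files at once; AUTOTUNE decides how many of those reads
# actually proceed in parallel based on the host's available CPU.
dataset = filenames.interleave(
    tf.data.TFRecordDataset,
    cycle_length=10,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
```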
@@ -83,6 +83,11 @@ def process_record_dataset(dataset,
   tf.compat.v1.logging.info('datasets_num_private_threads: %s',
                             datasets_num_private_threads)
 
+  # Disable intra-op parallelism to optimize for throughput instead of latency.
+  options = tf.data.Options()
+  options.experimental_threading.max_intra_op_parallelism = 1
+  dataset = dataset.with_options(options)
+
   # Prefetches a batch at a time to smooth out the time taken to load input
   # files for shuffling and processing.
   dataset = dataset.prefetch(buffer_size=batch_size)
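The added `with_options` block caps the number of intra-op threads each tf.data operation may use. A minimal sketch of the same idiom in isolation; the `range` dataset is a stand-in for the real record dataset:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1000)  # stand-in for the record dataset

# Limit each tf.data op to one intra-op thread. A single element takes
# longer to process (latency), but threads are freed to work on many
# elements at once (throughput), the right trade-off for training input.
options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
dataset = dataset.with_options(options)
```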
@@ -94,12 +99,10 @@ def process_record_dataset(dataset,
   dataset = dataset.repeat(num_epochs)
 
   # Parses the raw records into images and labels.
-  dataset = dataset.apply(
-      tf.data.experimental.map_and_batch(
-          lambda value: parse_record_fn(value, is_training, dtype),
-          batch_size=batch_size,
-          num_parallel_batches=num_parallel_batches,
-          drop_remainder=False))
+  dataset = dataset.map(
+      lambda value: parse_record_fn(value, is_training, dtype),
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  dataset = dataset.batch(batch_size, drop_remainder=False)
 
   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
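The last hunk splits the fused `map_and_batch` transform into a plain `map` (parallelized via `AUTOTUNE`) followed by `batch`; tf.data's graph optimizations can fuse the two back together when that is profitable, so the explicit transform is no longer needed. A runnable sketch, with a hypothetical parser and `range` dataset standing in for the real `parse_record_fn` and record dataset:

```python
import tensorflow as tf

def parse_record_fn(value, is_training, dtype):
  # Stand-in for the real parser, which decodes a serialized
  # tf.Example into an (image, label) pair.
  return tf.cast(value, dtype)

is_training = True
dtype = tf.float32
batch_size = 32

dataset = tf.data.Dataset.range(100)  # stand-in for the record dataset
dataset = dataset.map(
    lambda value: parse_record_fn(value, is_training, dtype),
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=False)
```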