"vscode:/vscode.git/clone" did not exist on "d78ec6ead525b6c514bde6ee79237293fa1f9308"
Commit eebea3f8 authored by Haoyu Zhang
Browse files

Optimize data input pipeline


Co-authored-by: Jiri Simsa <jsimsa@google.com>
parent 8e7051a8
......@@ -200,12 +200,13 @@ def input_fn(is_training,
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
# Convert to individual records.
# cycle_length = 10 means 10 files will be read and deserialized in parallel.
# This number is low enough to not cause too much contention on small systems
# but high enough to provide the benefits of parallelization. You may want
# to increase this number if you have a large number of CPU cores.
dataset = dataset.apply(tf.data.experimental.parallel_interleave(
tf.data.TFRecordDataset, cycle_length=10))
# cycle_length = 10 means that up to 10 files will be read and deserialized in
# parallel. You may want to increase this number if you have a large number of
# CPU cores.
dataset = dataset.interleave(
tf.data.TFRecordDataset,
cycle_length=10,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
return resnet_run_loop.process_record_dataset(
dataset=dataset,
......
......@@ -83,6 +83,11 @@ def process_record_dataset(dataset,
tf.compat.v1.logging.info('datasets_num_private_threads: %s',
datasets_num_private_threads)
# Disable intra-op parallelism to optimize for throughput instead of latency.
options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
dataset = dataset.with_options(options)
# Prefetches a batch at a time to smooth out the time taken to load input
# files for shuffling and processing.
dataset = dataset.prefetch(buffer_size=batch_size)
......@@ -94,12 +99,10 @@ def process_record_dataset(dataset,
dataset = dataset.repeat(num_epochs)
# Parses the raw records into images and labels.
dataset = dataset.apply(
tf.data.experimental.map_and_batch(
lambda value: parse_record_fn(value, is_training, dtype),
batch_size=batch_size,
num_parallel_batches=num_parallel_batches,
drop_remainder=False))
dataset = dataset.map(
lambda value: parse_record_fn(value, is_training, dtype),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=False)
# Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment