"vscode:/vscode.git/clone" did not exist on "f8e3ce890c140f45c089778e0674dd66e68b1b49"
Commit 8f932583 authored by Zhichao Lu, committed by lzc5123016

Remove sharding from the input pipeline.

PiperOrigin-RevId: 185222703
parent fe31beae
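Net effect for callers: dataset_builder.build and dataset_util.read_dataset no longer take num_workers/worker_index, and read parallelism is now driven entirely by fields on the input_reader_pb2.InputReader proto. A minimal caller sketch under that assumption (the import paths, the tf_record_input_reader.input_path field, and the file pattern are illustrative assumptions, not part of this commit):

from object_detection.builders import dataset_builder  # assumed module path
from object_detection.protos import input_reader_pb2   # assumed module path

config = input_reader_pb2.InputReader()
# Hypothetical input pattern; the oneof/field name is assumed from context.
config.tf_record_input_reader.input_path.append('/data/train-*.tfrecord')

# Old call shape (removed by this commit):
#   dataset_builder.build(config, num_workers=..., worker_index=...)
# New call shape: parallelism comes from the InputReader fields instead.
dataset = dataset_builder.build(config, batch_size=1)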
@@ -21,7 +21,7 @@ Note: If users wishes to also use their own InputReaders with the Object
 Detection configuration framework, they should define their own builder function
 that wraps the build function.
 """
+import functools
 import tensorflow as tf
 from object_detection.core import standard_fields as fields
@@ -86,8 +86,8 @@ def _get_padding_shapes(dataset, max_num_boxes, num_classes,
           for tensor_key, _ in dataset.output_shapes.items()}
-def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
-          worker_index=0, batch_size=1, max_num_boxes=None, num_classes=None,
+def build(input_reader_config, transform_input_data_fn=None,
+          batch_size=1, max_num_boxes=None, num_classes=None,
           spatial_image_shape=None):
   """Builds a tf.data.Dataset.
@@ -100,8 +100,6 @@ def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
     input_reader_config: A input_reader_pb2.InputReader object.
     transform_input_data_fn: Function to apply to all records, or None if
       no extra decoding is required.
-    num_workers: Number of workers (tpu shard).
-    worker_index: Id for the current worker (tpu shard).
     batch_size: Batch size. If not None, returns a padded batch dataset.
     max_num_boxes: Max number of groundtruth boxes needed to computes shapes for
       padding. This is only used if batch_size is greater than 1.
@@ -146,8 +144,8 @@ def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
       return processed
     dataset = dataset_util.read_dataset(
-        tf.data.TFRecordDataset, process_fn, config.input_path[:],
-        input_reader_config, num_workers, worker_index)
+        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
+        process_fn, config.input_path[:], input_reader_config)
     if batch_size > 1:
       if num_classes is None:
......
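The new first argument to dataset_util.read_dataset is worth unpacking: functools.partial turns tf.data.TFRecordDataset into a one-argument callable (filename in, dataset out) with an 8 MB read buffer already bound, which is the shape of function the interleave stage below expects. A small sketch (the file path is hypothetical):

import functools

import tensorflow as tf

# Bind an 8 MB read buffer; the filename stays as the single positional argument.
file_read_func = functools.partial(
    tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000)

# Roughly equivalent lambda form:
#   file_read_func = lambda filename: tf.data.TFRecordDataset(
#       filename, buffer_size=8 * 1000 * 1000)

records = file_read_func('/tmp/example.tfrecord')  # hypothetical file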
@@ -32,7 +32,7 @@ message InputReader {
   optional bool shuffle = 2 [default=true];
   // Buffer size to be used when shuffling.
-  optional uint32 shuffle_buffer_size = 11 [default = 100];
+  optional uint32 shuffle_buffer_size = 11 [default = 2048];
   // Buffer size to be used when shuffling file names.
   optional uint32 filenames_shuffle_buffer_size = 12 [default = 100];
@@ -49,10 +49,13 @@ message InputReader {
   optional uint32 num_epochs = 5 [default=0];
   // Number of reader instances to create.
-  optional uint32 num_readers = 6 [default=8];
+  optional uint32 num_readers = 6 [default=32];
-  // Size of the buffer for prefetching (in batches).
-  optional uint32 prefetch_buffer_size = 13 [default = 2];
+  // Number of decoded records to prefetch before batching.
+  optional uint32 prefetch_size = 13 [default = 512];
+  // Number of parallel decode ops to apply.
+  optional uint32 num_parallel_map_calls = 14 [default = 64];
   // Whether to load groundtruth instance masks.
   optional bool load_instance_masks = 7 [default = false];
......
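For reference, the retuned defaults can also be set explicitly from Python; the field names below come straight from the proto hunk above, while the module path and override values are illustrative (a sketch, not part of the commit):

from object_detection.protos import input_reader_pb2  # assumed module path

reader_config = input_reader_pb2.InputReader()
reader_config.shuffle_buffer_size = 2048     # records buffered for shuffling
reader_config.num_readers = 32               # cycle_length of parallel_interleave
reader_config.prefetch_size = 512            # decoded records prefetched before batching
reader_config.num_parallel_map_calls = 64    # parallel decode ops in the map stage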
@@ -117,9 +117,7 @@ def main(_):
   def get_next(config):
     return dataset_util.make_initializable_iterator(
-        dataset_builder.build(
-            config, num_workers=FLAGS.worker_replicas,
-            worker_index=FLAGS.task)).get_next()
+        dataset_builder.build(config)).get_next()
   create_input_dict_fn = functools.partial(get_next, input_config)
......
@@ -103,9 +103,7 @@ def make_initializable_iterator(dataset):
   return iterator
-def read_dataset(
-    file_read_func, decode_func, input_files, config, num_workers=1,
-    worker_index=0):
+def read_dataset(file_read_func, decode_func, input_files, config):
   """Reads a dataset, and handles repetition and shuffling.
   Args:
@@ -114,8 +112,6 @@ def read_dataset(
     decode_func: Function to apply to all records.
     input_files: A list of file paths to read.
     config: A input_reader_builder.InputReader object.
-    num_workers: Number of workers / shards.
-    worker_index: Id for the current worker.
   Returns:
     A tf.data.Dataset based on config.
@@ -123,25 +119,17 @@ def read_dataset(
   # Shard, shuffle, and read files.
   filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
                         0)
-  dataset = tf.data.Dataset.from_tensor_slices(filenames)
-  dataset = dataset.shard(num_workers, worker_index)
-  dataset = dataset.repeat(config.num_epochs or None)
+  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
   if config.shuffle:
-    dataset = dataset.shuffle(config.filenames_shuffle_buffer_size,
-                              reshuffle_each_iteration=True)
-  # Read file records and shuffle them.
-  # If cycle_length is larger than the number of files, more than one reader
-  # will be assigned to the same file, leading to repetition.
-  cycle_length = tf.cast(
-      tf.minimum(config.num_readers, tf.size(filenames)), tf.int64)
-  # TODO: find the optimal block_length.
-  dataset = dataset.interleave(
-      file_read_func, cycle_length=cycle_length, block_length=1)
+    filename_dataset = filename_dataset.shuffle(
+        config.filenames_shuffle_buffer_size)
+  filename_dataset = filename_dataset.repeat(config.num_epochs or None)
+  records_dataset = filename_dataset.apply(
+      tf.contrib.data.parallel_interleave(
+          file_read_func, cycle_length=config.num_readers, sloppy=True))
   if config.shuffle:
-    dataset = dataset.shuffle(config.shuffle_buffer_size,
-                              reshuffle_each_iteration=True)
-  dataset = dataset.map(decode_func, num_parallel_calls=config.num_readers)
-  return dataset.prefetch(config.prefetch_buffer_size)
+    records_dataset.shuffle(config.shuffle_buffer_size)
+  tensor_dataset = records_dataset.map(
+      decode_func, num_parallel_calls=config.num_parallel_map_calls)
+  return tensor_dataset.prefetch(config.prefetch_size)
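Putting the new read_dataset body together as a standalone sketch (assumptions: TF 1.x with tf.contrib.data available; the file pattern and decode function are hypothetical, and the constants mirror the new proto defaults). One detail worth noting: tf.data transformations return new dataset objects, so the sketch reassigns the result of shuffle, whereas the diff's records_dataset.shuffle(config.shuffle_buffer_size) line discards it.

import functools

import tensorflow as tf


def decode_func(serialized):
  # Hypothetical decoder; the real pipeline uses the builder's process_fn.
  return tf.parse_single_example(
      serialized, {'image/encoded': tf.FixedLenFeature([], tf.string)})


# Hypothetical input pattern.
filenames = tf.matching_files('/path/to/train-*.tfrecord')
filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
filename_dataset = filename_dataset.shuffle(100).repeat(None)

# Read many files concurrently; sloppy=True trades determinism for throughput.
records_dataset = filename_dataset.apply(
    tf.contrib.data.parallel_interleave(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        cycle_length=32, sloppy=True))

# Shuffle raw records, decode them in parallel, and prefetch decoded records.
records_dataset = records_dataset.shuffle(2048)
tensor_dataset = records_dataset.map(decode_func, num_parallel_calls=64)
dataset = tensor_dataset.prefetch(512)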