Commit 5d92fa8d authored by Yoni Ben-Meshulam, committed by TF Object Detection Team

Introduce weighted dataset sampling to TF Object Detection.

When InputReader.sample_from_datasets_weights is set, these weights are used to sample from the individual input files.

PiperOrigin-RevId: 336128228
parent dae05499
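As a rough illustration of the new knob (not part of this commit), the field could be set on an InputReader proto much as the unit tests below do. The tf_record_input_reader.input_path field is assumed from the API's input_reader.proto, and the record file names are made up:

from object_detection.protos import input_reader_pb2

config = input_reader_pb2.InputReader()
# Two record files; the weights below are matched to them by position.
config.tf_record_input_reader.input_path.extend(
    ['dataset_a.record', 'dataset_b.record'])
# Roughly 90% of examples are drawn from dataset_a and 10% from dataset_b.
config.sample_from_datasets_weights.extend([0.9, 0.1])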
@@ -50,20 +50,22 @@ def make_initializable_iterator(dataset):
return iterator
def read_dataset(file_read_func, input_files, config,
def _read_dataset_internal(file_read_func,
input_files,
config,
filename_shard_fn=None):
"""Reads a dataset, and handles repetition and shuffling.
Args:
file_read_func: Function to use in tf_data.parallel_interleave, to
read every individual file into a tf.data.Dataset.
file_read_func: Function to use in tf_data.parallel_interleave, to read
every individual file into a tf.data.Dataset.
input_files: A list of file paths to read.
config: An input_reader_builder.InputReader object.
filename_shard_fn: optional, A funciton used to shard filenames across
replicas. This function takes as input a TF dataset of filenames and
is expected to return its sharded version. It is useful when the
dataset is being loaded on one of possibly many replicas and we want
to evenly shard the files between the replicas.
filename_shard_fn: optional, A function used to shard filenames across
replicas. This function takes as input a TF dataset of filenames and is
expected to return its sharded version. It is useful when the dataset is
being loaded on one of possibly many replicas and we want to evenly shard
the files between the replicas.
Returns:
A tf.data.Dataset of (undecoded) tf-records based on config.
@@ -71,8 +73,9 @@ def read_dataset(file_read_func, input_files, config,
Raises:
RuntimeError: If no files are found at the supplied path(s).
"""
# Shard, shuffle, and read files.
filenames = tf.gfile.Glob(input_files)
tf.logging.info('Reading record datasets for input file: %s' % input_files)
tf.logging.info('Number of filenames to read: %s' % len(filenames))
if not filenames:
raise RuntimeError('Did not find any input files matching the glob pattern '
'{}'.format(input_files))
@@ -103,6 +106,50 @@ def read_dataset(file_read_func, input_files, config,
return records_dataset
def read_dataset(file_read_func, input_files, config, filename_shard_fn=None):
"""Reads multiple datasets with sampling.
Args:
file_read_func: Function to use in tf_data.parallel_interleave, to read
every individual file into a tf.data.Dataset.
input_files: A list of file paths to read.
config: An input_reader_builder.InputReader object.
filename_shard_fn: optional, A function used to shard filenames across
replicas. This function takes as input a TF dataset of filenames and is
expected to return its sharded version. It is useful when the dataset is
being loaded on one of possibly many replicas and we want to evenly shard
the files between the replicas.
Returns:
A tf.data.Dataset of (undecoded) tf-records based on config.
Raises:
RuntimeError: If no files are found at the supplied path(s).
"""
if config.sample_from_datasets_weights:
tf.logging.info('Reading weighted datasets: %s' % input_files)
if len(input_files) != len(config.sample_from_datasets_weights):
raise ValueError('Expected the number of input files to be the same as '
'the number of dataset sample weights. But got '
'[input_files, sample_from_datasets_weights]: [' +
str(input_files) + ', ' +
str(config.sample_from_datasets_weights) + ']')
tf.logging.info('Sampling from datasets %s with weights %s' %
(input_files, config.sample_from_datasets_weights))
records_datasets = []
for input_file in input_files:
records_dataset = _read_dataset_internal(file_read_func, [input_file],
config, filename_shard_fn)
records_datasets.append(records_dataset)
dataset_weights = list(config.sample_from_datasets_weights)
return tf.data.experimental.sample_from_datasets(records_datasets,
dataset_weights)
else:
tf.logging.info('Reading unweighted datasets: %s' % input_files)
return _read_dataset_internal(file_read_func, input_files, config,
filename_shard_fn)
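The sampling itself is delegated to tf.data.experimental.sample_from_datasets, which draws each output element from one of the input datasets with probability given by its weight. A minimal, self-contained sketch of that primitive (assuming TF 2.x eager execution; not part of this commit):

import tensorflow as tf

ds_a = tf.data.Dataset.from_tensor_slices([1, 10]).repeat()
ds_b = tf.data.Dataset.from_tensor_slices([2, 20]).repeat()
# Each element of `mixed` comes from ds_a with probability 0.1 and from ds_b
# with probability 0.9.
mixed = tf.data.experimental.sample_from_datasets(
    [ds_a, ds_b], weights=[0.1, 0.9])
for value in mixed.take(5):
  print(value.numpy())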
def shard_function_for_context(input_context):
"""Returns a function that shards filenames based on the input context."""
......
@@ -532,6 +532,9 @@ class ReadDatasetTest(test_case.TestCase):
return get_iterator_next_for_testing(dataset, self.is_tf2())
def _assert_item_count(self, data, item, percentage):
self.assertAlmostEqual(data.count(item)/len(data), percentage, places=1)
def test_make_initializable_iterator_with_hashTable(self):
def graph_fn():
@@ -554,6 +557,66 @@ class ReadDatasetTest(test_case.TestCase):
result = self.execute(graph_fn, [])
self.assertAllEqual(result, [-1, 100, 1, 100])
def test_read_dataset_sample_from_datasets_weights_equal_weight(self):
"""Ensure that the files' values are equally-weighted."""
config = input_reader_pb2.InputReader()
config.num_readers = 2
config.shuffle = False
config.sample_from_datasets_weights.extend([0.5, 0.5])
def graph_fn():
return self._get_dataset_next(
[self._path_template % '0', self._path_template % '1'],
config,
batch_size=1000)
data = list(self.execute(graph_fn, []))
self.assertEqual(len(data), 1000)
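    # Each test file contributes two distinct values (1, 10 and 2, 20), so with
    # equal dataset weights each value should make up roughly 0.5 / 2 = 0.25 of
    # the sampled batch.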
self._assert_item_count(data, 1, 0.25)
self._assert_item_count(data, 10, 0.25)
self._assert_item_count(data, 2, 0.25)
self._assert_item_count(data, 20, 0.25)
def test_read_dataset_sample_from_datasets_weights_zero_weight(self):
"""Ensure that the files' values are equally-weighted."""
config = input_reader_pb2.InputReader()
config.num_readers = 2
config.shuffle = False
config.sample_from_datasets_weights.extend([1.0, 0.0])
def graph_fn():
return self._get_dataset_next(
[self._path_template % '0', self._path_template % '1'],
config,
batch_size=1000)
data = list(self.execute(graph_fn, []))
self.assertEqual(len(data), 1000)
self._assert_item_count(data, 1, 0.5)
self._assert_item_count(data, 10, 0.5)
self._assert_item_count(data, 2, 0.0)
self._assert_item_count(data, 20, 0.0)
def test_read_dataset_sample_from_datasets_weights_unbalanced(self):
"""Ensure that the files' values are equally-weighted."""
config = input_reader_pb2.InputReader()
config.num_readers = 2
config.shuffle = False
config.sample_from_datasets_weights.extend([0.1, 0.9])
def graph_fn():
return self._get_dataset_next(
[self._path_template % '0', self._path_template % '1'],
config,
batch_size=1000)
data = list(self.execute(graph_fn, []))
self.assertEqual(len(data), 1000)
self._assert_item_count(data, 1, 0.05)
self._assert_item_count(data, 10, 0.05)
self._assert_item_count(data, 2, 0.45)
self._assert_item_count(data, 20, 0.45)
def test_read_dataset(self):
config = input_reader_pb2.InputReader()
config.num_readers = 1
......
@@ -30,7 +30,7 @@ enum InputType {
TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input
}
// Next id: 34
// Next id: 35
message InputReader {
// Name of input reader. Typically used to describe the dataset that is read
// by this input reader.
@@ -61,6 +61,9 @@ message InputReader {
optional uint32 sample_1_of_n_examples = 22 [default = 1];
// Number of file shards to read in parallel.
//
// When sample_from_datasets_weights are configured, num_readers is applied
// for each dataset.
optional uint32 num_readers = 6 [default = 64];
// Number of batches to produce in parallel. If this is run on a 2x2 TPU set
@@ -144,6 +147,18 @@ message InputReader {
ExternalInputReader external_input_reader = 9;
}
// When multiple input files are configured, we can sample across them based
// on weights.
//
// The number of weights must match the number of input files configured.
//
// When set, shuffling, shuffle buffer size, and num_readers settings are
// applied individually to each dataset.
//
// Implementation follows tf.data.experimental.sample_from_datasets sampling
// strategy.
repeated float sample_from_datasets_weights = 34;
// Expand labels to ancestors or descendants in the hierarchy
// for positive and negative labels, respectively.
......
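For reference, a hedged sketch (not from this commit) of how the new field might appear in an input reader config, written as a text-format proto and parsed from Python; the message and field names come from input_reader.proto, and the record paths are hypothetical:

from google.protobuf import text_format
from object_detection.protos import input_reader_pb2

config_text = """
  tf_record_input_reader {
    input_path: 'data/dataset_a.record'
    input_path: 'data/dataset_b.record'
  }
  sample_from_datasets_weights: 0.9
  sample_from_datasets_weights: 0.1
"""
config = text_format.Parse(config_text, input_reader_pb2.InputReader())
# One weight per configured input file, as read_dataset requires.
assert len(config.sample_from_datasets_weights) == 2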