"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "696b69a498b43f8e6a1ecb24bb82f7b9db87c570"
Commit 5d92fa8d authored by Yoni Ben-Meshulam, committed by TF Object Detection Team

Introduce weighted dataset sampling to TF Object Detection.

When InputReader.sample_from_datasets_weights is set, these weights are used to sample records from the individual input files.

PiperOrigin-RevId: 336128228
parent dae05499
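For orientation, a minimal usage sketch of the new behavior (not part of the commit; it assumes the usual object_detection module layout, and the file names are placeholders): build an InputReader config with one weight per input file, then pass it to read_dataset.

# Hedged sketch: assumed module paths and hypothetical file names, not from this commit.
import functools
import tensorflow.compat.v1 as tf

from object_detection.builders import dataset_builder
from object_detection.protos import input_reader_pb2

config = input_reader_pb2.InputReader()
config.num_readers = 2
config.shuffle = False
# One weight per input file: ~10% of records come from file_a, ~90% from file_b.
config.sample_from_datasets_weights.extend([0.1, 0.9])

dataset = dataset_builder.read_dataset(
    functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
    ['file_a.tfrecord', 'file_b.tfrecord'],
    config)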
@@ -50,20 +50,22 @@ def make_initializable_iterator(dataset):
  return iterator


def _read_dataset_internal(file_read_func,
                           input_files,
                           config,
                           filename_shard_fn=None):
  """Reads a dataset, and handles repetition and shuffling.

  Args:
    file_read_func: Function to use in tf_data.parallel_interleave, to read
      every individual file into a tf.data.Dataset.
    input_files: A list of file paths to read.
    config: A input_reader_builder.InputReader object.
    filename_shard_fn: optional, A function used to shard filenames across
      replicas. This function takes as input a TF dataset of filenames and is
      expected to return its sharded version. It is useful when the dataset is
      being loaded on one of possibly many replicas and we want to evenly shard
      the files between the replicas.

  Returns:
    A tf.data.Dataset of (undecoded) tf-records based on config.
@@ -71,8 +73,9 @@ def read_dataset(file_read_func, input_files, config,
  Raises:
    RuntimeError: If no files are found at the supplied path(s).
  """
  filenames = tf.gfile.Glob(input_files)
  tf.logging.info('Reading record datasets for input file: %s' % input_files)
  tf.logging.info('Number of filenames to read: %s' % len(filenames))
  if not filenames:
    raise RuntimeError('Did not find any input files matching the glob pattern '
                       '{}'.format(input_files))
@@ -103,6 +106,50 @@ def read_dataset(file_read_func, input_files, config,
  return records_dataset


def read_dataset(file_read_func, input_files, config, filename_shard_fn=None):
  """Reads multiple datasets with sampling.

  Args:
    file_read_func: Function to use in tf_data.parallel_interleave, to read
      every individual file into a tf.data.Dataset.
    input_files: A list of file paths to read.
    config: A input_reader_builder.InputReader object.
    filename_shard_fn: optional, A function used to shard filenames across
      replicas. This function takes as input a TF dataset of filenames and is
      expected to return its sharded version. It is useful when the dataset is
      being loaded on one of possibly many replicas and we want to evenly shard
      the files between the replicas.

  Returns:
    A tf.data.Dataset of (undecoded) tf-records based on config.

  Raises:
    RuntimeError: If no files are found at the supplied path(s).
    ValueError: If the number of dataset sample weights does not match the
      number of input files.
  """
  if config.sample_from_datasets_weights:
    tf.logging.info('Reading weighted datasets: %s' % input_files)
    if len(input_files) != len(config.sample_from_datasets_weights):
      raise ValueError('Expected the number of input files to be the same as '
                       'the number of dataset sample weights. But got '
                       '[input_files, sample_from_datasets_weights]: [' +
                       str(input_files) + ', ' +
                       str(config.sample_from_datasets_weights) + ']')
    tf.logging.info('Sampling from datasets %s with weights %s' %
                    (input_files, config.sample_from_datasets_weights))
    records_datasets = []
    for input_file in input_files:
      records_dataset = _read_dataset_internal(file_read_func, [input_file],
                                               config, filename_shard_fn)
      records_datasets.append(records_dataset)
    dataset_weights = list(config.sample_from_datasets_weights)
    return tf.data.experimental.sample_from_datasets(records_datasets,
                                                     dataset_weights)
  else:
    tf.logging.info('Reading unweighted datasets: %s' % input_files)
    return _read_dataset_internal(file_read_func, input_files, config,
                                  filename_shard_fn)


def shard_function_for_context(input_context):
  """Returns a function that shards filenames based on the input context."""
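The weighted branch of read_dataset above delegates the mixing itself to tf.data.experimental.sample_from_datasets. A self-contained sketch of that sampling behavior, using toy in-memory datasets in place of the per-file record datasets (assumes TF 2.x eager execution; not part of the commit):

import tensorflow as tf

# Two toy streams standing in for the per-file datasets built by
# _read_dataset_internal.
dataset_a = tf.data.Dataset.from_tensor_slices([1, 10]).repeat()
dataset_b = tf.data.Dataset.from_tensor_slices([2, 20]).repeat()

# Elements are drawn from dataset_a ~10% of the time and from dataset_b ~90%.
mixed = tf.data.experimental.sample_from_datasets(
    [dataset_a, dataset_b], weights=[0.1, 0.9])

for element in mixed.take(10):
  print(element.numpy())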
@@ -532,6 +532,9 @@ class ReadDatasetTest(test_case.TestCase):
    return get_iterator_next_for_testing(dataset, self.is_tf2())

def _assert_item_count(self, data, item, percentage):
self.assertAlmostEqual(data.count(item)/len(data), percentage, places=1)
  def test_make_initializable_iterator_with_hashTable(self):

    def graph_fn():
@@ -554,6 +557,66 @@ class ReadDatasetTest(test_case.TestCase):
    result = self.execute(graph_fn, [])
    self.assertAllEqual(result, [-1, 100, 1, 100])

def test_read_dataset_sample_from_datasets_weights_equal_weight(self):
"""Ensure that the files' values are equally-weighted."""
config = input_reader_pb2.InputReader()
config.num_readers = 2
config.shuffle = False
config.sample_from_datasets_weights.extend([0.5, 0.5])
def graph_fn():
return self._get_dataset_next(
[self._path_template % '0', self._path_template % '1'],
config,
batch_size=1000)
data = list(self.execute(graph_fn, []))
self.assertEqual(len(data), 1000)
self._assert_item_count(data, 1, 0.25)
self._assert_item_count(data, 10, 0.25)
self._assert_item_count(data, 2, 0.25)
self._assert_item_count(data, 20, 0.25)
def test_read_dataset_sample_from_datasets_weights_zero_weight(self):
"""Ensure that the files' values are equally-weighted."""
config = input_reader_pb2.InputReader()
config.num_readers = 2
config.shuffle = False
config.sample_from_datasets_weights.extend([1.0, 0.0])
def graph_fn():
return self._get_dataset_next(
[self._path_template % '0', self._path_template % '1'],
config,
batch_size=1000)
data = list(self.execute(graph_fn, []))
self.assertEqual(len(data), 1000)
self._assert_item_count(data, 1, 0.5)
self._assert_item_count(data, 10, 0.5)
self._assert_item_count(data, 2, 0.0)
self._assert_item_count(data, 20, 0.0)
def test_read_dataset_sample_from_datasets_weights_unbalanced(self):
"""Ensure that the files' values are equally-weighted."""
config = input_reader_pb2.InputReader()
config.num_readers = 2
config.shuffle = False
config.sample_from_datasets_weights.extend([0.1, 0.9])
def graph_fn():
return self._get_dataset_next(
[self._path_template % '0', self._path_template % '1'],
config,
batch_size=1000)
data = list(self.execute(graph_fn, []))
self.assertEqual(len(data), 1000)
self._assert_item_count(data, 1, 0.05)
self._assert_item_count(data, 10, 0.05)
self._assert_item_count(data, 2, 0.45)
self._assert_item_count(data, 20, 0.45)
  def test_read_dataset(self):
    config = input_reader_pb2.InputReader()
    config.num_readers = 1
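As a sanity check on the assertions in the unbalanced-weights test above: each file contributes two distinct values in equal measure, so a value's expected share of the sampled stream is its file's weight divided by two. A tiny sketch of that arithmetic (names are illustrative only):

# Weights from the unbalanced test: file 0 is sampled 10% of the time, file 1 90%.
weights = [0.1, 0.9]
values_per_file = 2

# Expected per-value proportions: 0.05 for each value from file 0 and 0.45 for
# each value from file 1, matching the _assert_item_count calls above.
expected = [w / values_per_file for w in weights]
print(expected)  # [0.05, 0.45]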
@@ -30,7 +30,7 @@ enum InputType {
  TF_SEQUENCE_EXAMPLE = 2;  // TfSequenceExample Input
}

// Next id: 35
message InputReader {
  // Name of input reader. Typically used to describe the dataset that is read
  // by this input reader.
@@ -61,6 +61,9 @@ message InputReader {
  optional uint32 sample_1_of_n_examples = 22 [default = 1];

  // Number of file shards to read in parallel.
//
// When sample_from_datasets_weights are configured, num_readers is applied
// for each dataset.
  optional uint32 num_readers = 6 [default = 64];

  // Number of batches to produce in parallel. If this is run on a 2x2 TPU set
@@ -144,6 +147,18 @@ message InputReader {
    ExternalInputReader external_input_reader = 9;
  }

// When multiple input files are configured, we can sample across them based
// on weights.
//
// The number of weights must match the number of input files configured.
//
// When set, shuffling, shuffle buffer size, and num_readers settings are
// applied individually to each dataset.
//
// Implementation follows tf.data.experimental.sample_from_datasets sampling
// strategy.
repeated float sample_from_datasets_weights = 34;
  // Expand labels to ancestors or descendants in the hierarchy
  // for positive and negative labels, respectively.
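For reference, a hedged sketch of how the new field could be set in a text-format InputReader config (the tf_record_input_reader block reflects the existing proto, and the input paths are placeholders, not part of this change):

from google.protobuf import text_format
from object_detection.protos import input_reader_pb2

# Hypothetical pipeline-config snippet; the input paths are placeholders.
config = text_format.Parse(
    """
    num_readers: 2
    sample_from_datasets_weights: 0.1
    sample_from_datasets_weights: 0.9
    tf_record_input_reader {
      input_path: "file_a.tfrecord"
      input_path: "file_b.tfrecord"
    }
    """, input_reader_pb2.InputReader())
print(config.sample_from_datasets_weights)  # [0.1, 0.9]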