Commit d04c9e9b authored by Yoni Ben-Meshulam, committed by TF Object Detection Team

Use dataset weights to weight the number of input readers

PiperOrigin-RevId: 353788624
parent 219274da
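
In brief: instead of giving every dataset the full config.num_readers readers, this commit scales each dataset's reader count by its sampling weight. A minimal sketch of the new allocation arithmetic, using hypothetical num_readers and weight values (not from the commit):

  import math

  num_readers = 10             # stands in for config.num_readers
  weights = [2.0, 1.0, 1.0]    # stands in for config.sample_from_datasets_weights

  for weight in weights:
    # Readers scale with relative weight, rounded up, so any dataset with a
    # nonzero weight keeps at least one reader.
    print(math.ceil(num_readers * weight / sum(weights)))  # -> 5, 3, 3

Note that math.ceil rounds up per dataset, so the total (11 here) can slightly exceed config.num_readers.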
@@ -27,6 +27,7 @@ from __future__ import division
 from __future__ import print_function
 import functools
+import math
 import tensorflow.compat.v1 as tf
 from object_detection.builders import decoder_builder
@@ -52,6 +53,7 @@ def make_initializable_iterator(dataset):
 def _read_dataset_internal(file_read_func,
                            input_files,
+                           num_readers,
                            config,
                            filename_shard_fn=None):
   """Reads a dataset, and handles repetition and shuffling.
@@ -60,6 +62,7 @@ def _read_dataset_internal(file_read_func,
     file_read_func: Function to use in tf_data.parallel_interleave, to read
       every individual file into a tf.data.Dataset.
     input_files: A list of file paths to read.
+    num_readers: Number of readers to use.
     config: A input_reader_builder.InputReader object.
     filename_shard_fn: optional, A function used to shard filenames across
       replicas. This function takes as input a TF dataset of filenames and is
@@ -79,7 +82,6 @@ def _read_dataset_internal(file_read_func,
   if not filenames:
     raise RuntimeError('Did not find any input files matching the glob pattern '
                        '{}'.format(input_files))
-  num_readers = config.num_readers
   if num_readers > len(filenames):
     num_readers = len(filenames)
     tf.logging.warning('num_readers has been reduced to %d to match input file '
@@ -137,17 +139,30 @@ def read_dataset(file_read_func, input_files, config, filename_shard_fn=None):
     tf.logging.info('Sampling from datasets %s with weights %s' %
                     (input_files, config.sample_from_datasets_weights))
     records_datasets = []
-    for input_file in input_files:
+    dataset_weights = []
+    for i, input_file in enumerate(input_files):
+      weight = config.sample_from_datasets_weights[i]
+      num_readers = math.ceil(config.num_readers *
+                              weight /
+                              sum(config.sample_from_datasets_weights))
+      tf.logging.info(
+          'Num readers for dataset [%s]: %d', input_file, num_readers)
+      if num_readers == 0:
+        tf.logging.info('Skipping dataset due to zero weights: %s', input_file)
+        continue
+      tf.logging.info(
+          'Num readers for dataset [%s]: %d', input_file, num_readers)
       records_dataset = _read_dataset_internal(file_read_func, [input_file],
-                                               config, filename_shard_fn)
+                                               num_readers, config,
+                                               filename_shard_fn)
+      dataset_weights.append(weight)
       records_datasets.append(records_dataset)
-    dataset_weights = list(config.sample_from_datasets_weights)
     return tf.data.experimental.sample_from_datasets(records_datasets,
                                                      dataset_weights)
   else:
     tf.logging.info('Reading unweighted datasets: %s' % input_files)
-    return _read_dataset_internal(file_read_func, input_files, config,
-                                  filename_shard_fn)
+    return _read_dataset_internal(file_read_func, input_files,
+                                  config.num_readers, config, filename_shard_fn)


 def shard_function_for_context(input_context):
......
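
For context, read_dataset then mixes the per-file record datasets with tf.data.experimental.sample_from_datasets, roughly as in this standalone sketch (dataset contents are made up for illustration):

  import tensorflow.compat.v1 as tf

  ds_a = tf.data.Dataset.from_tensors('a').repeat()
  ds_b = tf.data.Dataset.from_tensors('b').repeat()

  # Weights are relative, like the dataset_weights list collected above;
  # in expectation ~75% of sampled elements come from ds_a.
  mixed = tf.data.experimental.sample_from_datasets([ds_a, ds_b],
                                                    weights=[0.75, 0.25])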
@@ -161,12 +161,17 @@ message InputReader {
   //
   // The number of weights must match the number of input files configured.
   //
-  // When set, shuffling, shuffle buffer size, and num_readers settings are
+  // The number of input readers per dataset is num_readers, scaled relative to
+  // the dataset weight.
+  //
+  // When set, shuffling and shuffle buffer size settings are
   // applied individually to each dataset.
   //
   // Implementation follows tf.data.experimental.sample_from_datasets sampling
   // strategy. Weights may take any value - only relative weights matter.
-  // Zero weights will result in a dataset not being sampled.
+  //
+  // Zero weights will result in a dataset not being sampled and no input
+  // readers spawned.
   //
   // Examples, assuming two input files configured:
   //
......
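
To illustrate the zero-weight clause with two input files (hypothetical values):

  import math

  num_readers = 10        # num_readers from the config
  weights = [1.0, 0.0]    # one sample_from_datasets_weights entry per input file

  for weight in weights:
    print(math.ceil(num_readers * weight / sum(weights)))  # -> 10, then 0

The second dataset gets zero readers, so read_dataset skips it outright instead of sampling it with probability zero.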