Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d04c9e9b
Commit
d04c9e9b
authored
Jan 25, 2021
by
Yoni Ben-Meshulam
Committed by
TF Object Detection Team
Jan 25, 2021
Browse files
Use dataset weights to weight the number of input readers
PiperOrigin-RevId: 353788624
parent
219274da
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
8 deletions
+28
-8
research/object_detection/builders/dataset_builder.py
research/object_detection/builders/dataset_builder.py
+21
-6
research/object_detection/protos/input_reader.proto
research/object_detection/protos/input_reader.proto
+7
-2
No files found.
research/object_detection/builders/dataset_builder.py
View file @
d04c9e9b
...
@@ -27,6 +27,7 @@ from __future__ import division
...
@@ -27,6 +27,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
functools
import
functools
import
math
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
from
object_detection.builders
import
decoder_builder
from
object_detection.builders
import
decoder_builder
...
@@ -52,6 +53,7 @@ def make_initializable_iterator(dataset):
...
@@ -52,6 +53,7 @@ def make_initializable_iterator(dataset):
def
_read_dataset_internal
(
file_read_func
,
def
_read_dataset_internal
(
file_read_func
,
input_files
,
input_files
,
num_readers
,
config
,
config
,
filename_shard_fn
=
None
):
filename_shard_fn
=
None
):
"""Reads a dataset, and handles repetition and shuffling.
"""Reads a dataset, and handles repetition and shuffling.
...
@@ -60,6 +62,7 @@ def _read_dataset_internal(file_read_func,
...
@@ -60,6 +62,7 @@ def _read_dataset_internal(file_read_func,
file_read_func: Function to use in tf_data.parallel_interleave, to read
file_read_func: Function to use in tf_data.parallel_interleave, to read
every individual file into a tf.data.Dataset.
every individual file into a tf.data.Dataset.
input_files: A list of file paths to read.
input_files: A list of file paths to read.
num_readers: Number of readers to use.
config: An input_reader_builder.InputReader object.
config: An input_reader_builder.InputReader object.
filename_shard_fn: optional, A function used to shard filenames across
filename_shard_fn: optional, A function used to shard filenames across
replicas. This function takes as input a TF dataset of filenames and is
replicas. This function takes as input a TF dataset of filenames and is
...
@@ -79,7 +82,6 @@ def _read_dataset_internal(file_read_func,
...
@@ -79,7 +82,6 @@ def _read_dataset_internal(file_read_func,
if
not
filenames
:
if
not
filenames
:
raise
RuntimeError
(
'Did not find any input files matching the glob pattern '
raise
RuntimeError
(
'Did not find any input files matching the glob pattern '
'{}'
.
format
(
input_files
))
'{}'
.
format
(
input_files
))
num_readers
=
config
.
num_readers
if
num_readers
>
len
(
filenames
):
if
num_readers
>
len
(
filenames
):
num_readers
=
len
(
filenames
)
num_readers
=
len
(
filenames
)
tf
.
logging
.
warning
(
'num_readers has been reduced to %d to match input file '
tf
.
logging
.
warning
(
'num_readers has been reduced to %d to match input file '
...
@@ -137,17 +139,30 @@ def read_dataset(file_read_func, input_files, config, filename_shard_fn=None):
...
@@ -137,17 +139,30 @@ def read_dataset(file_read_func, input_files, config, filename_shard_fn=None):
tf
.
logging
.
info
(
'Sampling from datasets %s with weights %s'
%
tf
.
logging
.
info
(
'Sampling from datasets %s with weights %s'
%
(
input_files
,
config
.
sample_from_datasets_weights
))
(
input_files
,
config
.
sample_from_datasets_weights
))
records_datasets
=
[]
records_datasets
=
[]
for
input_file
in
input_files
:
dataset_weights
=
[]
for
i
,
input_file
in
enumerate
(
input_files
):
weight
=
config
.
sample_from_datasets_weights
[
i
]
num_readers
=
math
.
ceil
(
config
.
num_readers
*
weight
/
sum
(
config
.
sample_from_datasets_weights
))
tf
.
logging
.
info
(
'Num readers for dataset [%s]: %d'
,
input_file
,
num_readers
)
if
num_readers
==
0
:
tf
.
logging
.
info
(
'Skipping dataset due to zero weights: %s'
,
input_file
)
continue
tf
.
logging
.
info
(
'Num readers for dataset [%s]: %d'
,
input_file
,
num_readers
)
records_dataset
=
_read_dataset_internal
(
file_read_func
,
[
input_file
],
records_dataset
=
_read_dataset_internal
(
file_read_func
,
[
input_file
],
config
,
filename_shard_fn
)
num_readers
,
config
,
filename_shard_fn
)
dataset_weights
.
append
(
weight
)
records_datasets
.
append
(
records_dataset
)
records_datasets
.
append
(
records_dataset
)
dataset_weights
=
list
(
config
.
sample_from_datasets_weights
)
return
tf
.
data
.
experimental
.
sample_from_datasets
(
records_datasets
,
return
tf
.
data
.
experimental
.
sample_from_datasets
(
records_datasets
,
dataset_weights
)
dataset_weights
)
else
:
else
:
tf
.
logging
.
info
(
'Reading unweighted datasets: %s'
%
input_files
)
tf
.
logging
.
info
(
'Reading unweighted datasets: %s'
%
input_files
)
return
_read_dataset_internal
(
file_read_func
,
input_files
,
config
,
return
_read_dataset_internal
(
file_read_func
,
input_files
,
filename_shard_fn
)
config
.
num_readers
,
config
,
filename_shard_fn
)
def
shard_function_for_context
(
input_context
):
def
shard_function_for_context
(
input_context
):
...
...
research/object_detection/protos/input_reader.proto
View file @
d04c9e9b
...
@@ -161,12 +161,17 @@ message InputReader {
...
@@ -161,12 +161,17 @@ message InputReader {
//
//
// The number of weights must match the number of input files configured.
// The number of weights must match the number of input files configured.
//
//
// When set, shuffling, shuffle buffer size, and num_readers settings are
// The number of input readers per dataset is num_readers, scaled relative to
// the dataset weight.
//
// When set, shuffling and shuffle buffer size settings are
// applied individually to each dataset.
// applied individually to each dataset.
//
//
// Implementation follows tf.data.experimental.sample_from_datasets sampling
// Implementation follows tf.data.experimental.sample_from_datasets sampling
// strategy. Weights may take any value - only relative weights matter.
// strategy. Weights may take any value - only relative weights matter.
// Zero weights will result in a dataset not being sampled.
//
// Zero weights will result in a dataset not being sampled and no input
// readers spawned.
//
//
// Examples, assuming two input files configured:
// Examples, assuming two input files configured:
//
//
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment