ModelZoo / ResNet50_tensorflow

Commit 8f932583
Authored Feb 09, 2018 by Zhichao Lu; committed by lzc5123016 on Feb 13, 2018
Parent commit: fe31beae

Remove sharding from the input pipeline.

PiperOrigin-RevId: 185222703
Showing 4 changed files with 25 additions and 38 deletions (+25 -38).
research/object_detection/builders/dataset_builder.py (+5, -7)
research/object_detection/protos/input_reader.proto (+7, -4)
research/object_detection/train.py (+1, -3)
research/object_detection/utils/dataset_util.py (+12, -24)
research/object_detection/builders/dataset_builder.py
@@ -21,7 +21,7 @@ Note: If users wishes to also use their own InputReaders with the Object
 Detection configuration framework, they should define their own builder function
 that wraps the build function.
 """
+import functools
 import tensorflow as tf
 
 from object_detection.core import standard_fields as fields
@@ -86,8 +86,8 @@ def _get_padding_shapes(dataset, max_num_boxes, num_classes,
           for tensor_key, _ in dataset.output_shapes.items()}
 
 
-def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
-          worker_index=0, batch_size=1, max_num_boxes=None, num_classes=None,
+def build(input_reader_config, transform_input_data_fn=None,
+          batch_size=1, max_num_boxes=None, num_classes=None,
           spatial_image_shape=None):
   """Builds a tf.data.Dataset.
@@ -100,8 +100,6 @@ def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
     input_reader_config: A input_reader_pb2.InputReader object.
     transform_input_data_fn: Function to apply to all records, or None if
       no extra decoding is required.
-    num_workers: Number of workers (tpu shard).
-    worker_index: Id for the current worker (tpu shard).
     batch_size: Batch size. If not None, returns a padded batch dataset.
     max_num_boxes: Max number of groundtruth boxes needed to computes shapes for
       padding. This is only used if batch_size is greater than 1.
@@ -146,8 +144,8 @@ def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
       return processed
 
   dataset = dataset_util.read_dataset(
-      tf.data.TFRecordDataset, process_fn, config.input_path[:],
-      input_reader_config, num_workers, worker_index)
+      functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
+      process_fn, config.input_path[:], input_reader_config)
   if batch_size > 1:
     if num_classes is None:
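Besides dropping the num_workers and worker_index arguments, the builder now hands read_dataset a functools.partial wrapper instead of the bare tf.data.TFRecordDataset class, so every reader is constructed with an 8 MB read buffer. A minimal sketch of what that wrapper does, with a hypothetical record file name (TF 1.x API):

import functools
import tensorflow as tf

# Pre-bind buffer_size so the pipeline can call this with just a filename;
# equivalent to tf.data.TFRecordDataset(path, buffer_size=8 * 1000 * 1000).
file_read_func = functools.partial(
    tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000)

# Hypothetical record file, purely for illustration.
dataset = file_read_func("train-00000-of-00100.tfrecord")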
research/object_detection/protos/input_reader.proto
@@ -32,7 +32,7 @@ message InputReader {
   optional bool shuffle = 2 [default = true];
 
   // Buffer size to be used when shuffling.
-  optional uint32 shuffle_buffer_size = 11 [default = 100];
+  optional uint32 shuffle_buffer_size = 11 [default = 2048];
 
   // Buffer size to be used when shuffling file names.
   optional uint32 filenames_shuffle_buffer_size = 12 [default = 100];
@@ -49,10 +49,13 @@ message InputReader {
   optional uint32 num_epochs = 5 [default = 0];
 
   // Number of reader instances to create.
-  optional uint32 num_readers = 6 [default = 8];
+  optional uint32 num_readers = 6 [default = 32];
 
-  // Size of the buffer for prefetching (in batches).
-  optional uint32 prefetch_buffer_size = 13 [default = 2];
+  // Number of decoded records to prefetch before batching.
+  optional uint32 prefetch_size = 13 [default = 512];
+
+  // Number of parallel decode ops to apply.
+  optional uint32 num_parallel_map_calls = 14 [default = 64];
 
   // Whether to load groundtruth instance masks.
   optional bool load_instance_masks = 7 [default = false];
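For reference, the renamed and new fields can also be set programmatically through the generated proto class. A sketch only, assuming the object_detection package and its compiled protos are importable at this revision; the values simply mirror the new defaults above, and the tf_record_input_reader / input_path part is the usual TFRecord source rather than anything introduced by this commit:

from object_detection.protos import input_reader_pb2

reader_config = input_reader_pb2.InputReader()
reader_config.shuffle = True
reader_config.shuffle_buffer_size = 2048      # default raised from 100
reader_config.num_readers = 32                # default raised from 8
reader_config.prefetch_size = 512             # renamed from prefetch_buffer_size
reader_config.num_parallel_map_calls = 64     # new field
reader_config.tf_record_input_reader.input_path.append(
    "/path/to/train.record")                  # hypothetical path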
research/object_detection/train.py
@@ -117,9 +117,7 @@ def main(_):
 
   def get_next(config):
     return dataset_util.make_initializable_iterator(
-        dataset_builder.build(
-            config, num_workers=FLAGS.worker_replicas,
-            worker_index=FLAGS.task)).get_next()
+        dataset_builder.build(config)).get_next()
 
   create_input_dict_fn = functools.partial(get_next, input_config)
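With the worker arguments gone, building the training input reduces to dataset_builder.build(config). A hypothetical usage sketch of the simplified path (TF 1.x), assuming input_config is the InputReader proto parsed from the training pipeline config, for example as in the snippet above:

import tensorflow as tf
from object_detection.builders import dataset_builder
from object_detection.utils import dataset_util

# input_config: an input_reader_pb2.InputReader provided by the caller.
iterator = dataset_util.make_initializable_iterator(
    dataset_builder.build(input_config))
input_dict = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  first_example = sess.run(input_dict)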
research/object_detection/utils/dataset_util.py
@@ -103,9 +103,7 @@ def make_initializable_iterator(dataset):
   return iterator
 
 
-def read_dataset(
-    file_read_func, decode_func, input_files, config, num_workers=1,
-    worker_index=0):
+def read_dataset(file_read_func, decode_func, input_files, config):
   """Reads a dataset, and handles repetition and shuffling.
 
   Args:
@@ -114,8 +112,6 @@ def read_dataset(
     decode_func: Function to apply to all records.
     input_files: A list of file paths to read.
     config: A input_reader_builder.InputReader object.
-    num_workers: Number of workers / shards.
-    worker_index: Id for the current worker.
 
   Returns:
     A tf.data.Dataset based on config.
@@ -123,25 +119,17 @@ def read_dataset(
   # Shard, shuffle, and read files.
   filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
                         0)
-  dataset = tf.data.Dataset.from_tensor_slices(filenames)
-  dataset = dataset.shard(num_workers, worker_index)
-  dataset = dataset.repeat(config.num_epochs or None)
+  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
   if config.shuffle:
-    dataset = dataset.shuffle(config.filenames_shuffle_buffer_size,
-                              reshuffle_each_iteration=True)
-
-  # Read file records and shuffle them.
-  # If cycle_length is larger than the number of files, more than one reader
-  # will be assigned to the same file, leading to repetition.
-  cycle_length = tf.cast(
-      tf.minimum(config.num_readers, tf.size(filenames)), tf.int64)
-  # TODO: find the optimal block_length.
-  dataset = dataset.interleave(
-      file_read_func, cycle_length=cycle_length, block_length=1)
-
+    filename_dataset = filename_dataset.shuffle(
+        config.filenames_shuffle_buffer_size)
+  filename_dataset = filename_dataset.repeat(config.num_epochs or None)
+  records_dataset = filename_dataset.apply(
+      tf.contrib.data.parallel_interleave(
+          file_read_func, cycle_length=config.num_readers, sloppy=True))
   if config.shuffle:
-    dataset = dataset.shuffle(config.shuffle_buffer_size,
-                              reshuffle_each_iteration=True)
-
-  dataset = dataset.map(decode_func, num_parallel_calls=config.num_readers)
-  return dataset.prefetch(config.prefetch_buffer_size)
+    records_dataset.shuffle(config.shuffle_buffer_size)
+  tensor_dataset = records_dataset.map(
+      decode_func, num_parallel_calls=config.num_parallel_map_calls)
+  return tensor_dataset.prefetch(config.prefetch_size)
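Put together, read_dataset now shuffles and repeats the file list, reads files in parallel with tf.contrib.data.parallel_interleave instead of sharding plus a sequential interleave, shuffles the records, decodes them with a configurable number of parallel map calls, and prefetches decoded records. A self-contained sketch of the same pattern against plain TFRecord files (TF 1.x; the file pattern, decoder, and literal buffer sizes are illustrative, not the library's values):

import functools
import tensorflow as tf

# Reader factory with a pre-bound read buffer, as in dataset_builder.build.
file_read_func = functools.partial(
    tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000)

filenames = tf.matching_files("data/train-*.tfrecord")    # hypothetical pattern
filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
filename_dataset = filename_dataset.shuffle(100)          # filenames_shuffle_buffer_size
filename_dataset = filename_dataset.repeat(None)          # repeat indefinitely

# cycle_length readers pull records from different files concurrently;
# sloppy=True allows non-deterministic ordering for better throughput.
records_dataset = filename_dataset.apply(
    tf.contrib.data.parallel_interleave(
        file_read_func, cycle_length=32, sloppy=True))
records_dataset = records_dataset.shuffle(2048)           # shuffle_buffer_size

def decode_fn(serialized):
  # Placeholder decoder; the real pipeline parses full tf.Example protos.
  return tf.parse_single_example(
      serialized, {"image/encoded": tf.FixedLenFeature([], tf.string)})

tensor_dataset = records_dataset.map(decode_fn, num_parallel_calls=64)
tensor_dataset = tensor_dataset.prefetch(512)             # prefetch_size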