Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
df103208
Commit
df103208
authored
Jul 20, 2020
by
TF Object Detection Team
Browse files
Merge pull request #8909 from kmindspark:singleframe2
PiperOrigin-RevId: 322234001
parents
363a36cd
dbc211f2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
5 deletions
+54
-5
research/object_detection/inputs.py
research/object_detection/inputs.py
+6
-2
research/object_detection/inputs_test.py
research/object_detection/inputs_test.py
+43
-2
research/object_detection/protos/input_reader.proto
research/object_detection/protos/input_reader.proto
+5
-1
No files found.
research/object_detection/inputs.py
View file @
df103208
...
...
@@ -1094,8 +1094,12 @@ def get_reduce_to_frame_fn(input_reader_config, is_training):
num_frames
=
tf
.
cast
(
tf
.
shape
(
tensor_dict
[
fields
.
InputDataFields
.
source_id
])[
0
],
dtype
=
tf
.
int32
)
frame_index
=
tf
.
random
.
uniform
((),
minval
=
0
,
maxval
=
num_frames
,
dtype
=
tf
.
int32
)
if
input_reader_config
.
frame_index
==
-
1
:
frame_index
=
tf
.
random
.
uniform
((),
minval
=
0
,
maxval
=
num_frames
,
dtype
=
tf
.
int32
)
else
:
frame_index
=
tf
.
constant
(
input_reader_config
.
frame_index
,
dtype
=
tf
.
int32
)
out_tensor_dict
=
{}
for
key
in
tensor_dict
:
if
key
in
fields
.
SEQUENCE_FIELDS
:
...
...
research/object_detection/inputs_test.py
View file @
df103208
...
...
@@ -61,7 +61,7 @@ def _get_configs_for_model(model_name):
configs
,
kwargs_dict
=
override_dict
)
def
_get_configs_for_model_sequence_example
(
model_name
):
def
_get_configs_for_model_sequence_example
(
model_name
,
frame_index
=-
1
):
"""Returns configurations for model."""
fname
=
os
.
path
.
join
(
tf
.
resource_loader
.
get_data_files_path
(),
'test_data/'
+
model_name
+
'.config'
)
...
...
@@ -74,7 +74,8 @@ def _get_configs_for_model_sequence_example(model_name):
override_dict
=
{
'train_input_path'
:
data_path
,
'eval_input_path'
:
data_path
,
'label_map_path'
:
label_map_path
'label_map_path'
:
label_map_path
,
'frame_index'
:
frame_index
}
return
config_util
.
merge_external_params_with_configs
(
configs
,
kwargs_dict
=
override_dict
)
...
...
@@ -312,6 +313,46 @@ class InputFnTest(test_case.TestCase, parameterized.TestCase):
tf
.
float32
,
labels
[
fields
.
InputDataFields
.
groundtruth_weights
].
dtype
)
def
test_context_rcnn_resnet50_train_input_with_sequence_example_frame_index
(
self
,
train_batch_size
=
8
):
"""Tests the training input function for FasterRcnnResnet50."""
configs
=
_get_configs_for_model_sequence_example
(
'context_rcnn_camera_trap'
,
frame_index
=
2
)
model_config
=
configs
[
'model'
]
train_config
=
configs
[
'train_config'
]
train_config
.
batch_size
=
train_batch_size
train_input_fn
=
inputs
.
create_train_input_fn
(
train_config
,
configs
[
'train_input_config'
],
model_config
)
features
,
labels
=
_make_initializable_iterator
(
train_input_fn
()).
get_next
()
self
.
assertAllEqual
([
train_batch_size
,
640
,
640
,
3
],
features
[
fields
.
InputDataFields
.
image
].
shape
.
as_list
())
self
.
assertEqual
(
tf
.
float32
,
features
[
fields
.
InputDataFields
.
image
].
dtype
)
self
.
assertAllEqual
([
train_batch_size
],
features
[
inputs
.
HASH_KEY
].
shape
.
as_list
())
self
.
assertEqual
(
tf
.
int32
,
features
[
inputs
.
HASH_KEY
].
dtype
)
self
.
assertAllEqual
(
[
train_batch_size
,
100
,
4
],
labels
[
fields
.
InputDataFields
.
groundtruth_boxes
].
shape
.
as_list
())
self
.
assertEqual
(
tf
.
float32
,
labels
[
fields
.
InputDataFields
.
groundtruth_boxes
].
dtype
)
self
.
assertAllEqual
(
[
train_batch_size
,
100
,
model_config
.
faster_rcnn
.
num_classes
],
labels
[
fields
.
InputDataFields
.
groundtruth_classes
].
shape
.
as_list
())
self
.
assertEqual
(
tf
.
float32
,
labels
[
fields
.
InputDataFields
.
groundtruth_classes
].
dtype
)
self
.
assertAllEqual
(
[
train_batch_size
,
100
],
labels
[
fields
.
InputDataFields
.
groundtruth_weights
].
shape
.
as_list
())
self
.
assertEqual
(
tf
.
float32
,
labels
[
fields
.
InputDataFields
.
groundtruth_weights
].
dtype
)
self
.
assertAllEqual
(
[
train_batch_size
,
100
,
model_config
.
faster_rcnn
.
num_classes
],
labels
[
fields
.
InputDataFields
.
groundtruth_confidences
].
shape
.
as_list
())
self
.
assertEqual
(
tf
.
float32
,
labels
[
fields
.
InputDataFields
.
groundtruth_confidences
].
dtype
)
def
test_ssd_inceptionV2_train_input
(
self
):
"""Tests the training input function for SSDInceptionV2."""
configs
=
_get_configs_for_model
(
'ssd_inception_v2_pets'
)
...
...
research/object_detection/protos/input_reader.proto
View file @
df103208
...
...
@@ -31,7 +31,7 @@ enum InputType {
TF_SEQUENCE_EXAMPLE
=
2
;
// TfSequenceExample Input
}
// Next id: 3
2
// Next id: 3
3
message
InputReader
{
// Name of input reader. Typically used to describe the dataset that is read
// by this input reader.
...
...
@@ -133,6 +133,10 @@ message InputReader {
// Whether input data type is tf.Examples or tf.SequenceExamples
optional
InputType
input_type
=
30
[
default
=
TF_EXAMPLE
];
// Which frame to choose from the input if Sequence Example. -1 indicates
// random choice.
optional
int32
frame_index
=
32
[
default
=
-
1
];
oneof
input_reader
{
TFRecordInputReader
tf_record_input_reader
=
8
;
ExternalInputReader
external_input_reader
=
9
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment