Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bb43ed96
Commit
bb43ed96
authored
Mar 27, 2018
by
Zhichao Lu
Committed by
pkulzc
Apr 02, 2018
Browse files
Write groundtruth weights from input pipeline into model.
PiperOrigin-RevId: 190636417
parent
45069b91
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
63 deletions
+52
-63
research/object_detection/inputs.py
research/object_detection/inputs.py
+49
-62
research/object_detection/model.py
research/object_detection/model.py
+3
-1
No files found.
research/object_detection/inputs.py
View file @
bb43ed96
...
@@ -160,6 +160,52 @@ def augment_input_data(tensor_dict, data_augmentation_options):
...
@@ -160,6 +160,52 @@ def augment_input_data(tensor_dict, data_augmentation_options):
return
tensor_dict
return
tensor_dict
def
_get_labels_dict
(
input_dict
):
"""Extracts labels dict from input dict."""
required_label_keys
=
[
fields
.
InputDataFields
.
num_groundtruth_boxes
,
fields
.
InputDataFields
.
groundtruth_boxes
,
fields
.
InputDataFields
.
groundtruth_classes
,
fields
.
InputDataFields
.
groundtruth_weights
]
labels_dict
=
{}
for
key
in
required_label_keys
:
labels_dict
[
key
]
=
input_dict
[
key
]
optional_label_keys
=
[
fields
.
InputDataFields
.
groundtruth_keypoints
,
fields
.
InputDataFields
.
groundtruth_instance_masks
,
fields
.
InputDataFields
.
groundtruth_area
,
fields
.
InputDataFields
.
groundtruth_is_crowd
,
fields
.
InputDataFields
.
groundtruth_difficult
]
for
key
in
optional_label_keys
:
if
key
in
input_dict
:
labels_dict
[
key
]
=
input_dict
[
key
]
if
fields
.
InputDataFields
.
groundtruth_difficult
in
labels_dict
:
labels_dict
[
fields
.
InputDataFields
.
groundtruth_difficult
]
=
tf
.
cast
(
labels_dict
[
fields
.
InputDataFields
.
groundtruth_difficult
],
tf
.
int32
)
return
labels_dict
def
_get_features_dict
(
input_dict
):
"""Extracts features dict from input dict."""
hash_from_source_id
=
tf
.
string_to_hash_bucket_fast
(
input_dict
[
fields
.
InputDataFields
.
source_id
],
HASH_BINS
)
features
=
{
fields
.
InputDataFields
.
image
:
input_dict
[
fields
.
InputDataFields
.
image
],
HASH_KEY
:
tf
.
cast
(
hash_from_source_id
,
tf
.
int32
),
fields
.
InputDataFields
.
true_image_shape
:
input_dict
[
fields
.
InputDataFields
.
true_image_shape
]
}
if
fields
.
InputDataFields
.
original_image
in
input_dict
:
features
[
fields
.
InputDataFields
.
original_image
]
=
input_dict
[
fields
.
InputDataFields
.
original_image
]
return
features
def
create_train_input_fn
(
train_config
,
train_input_config
,
def
create_train_input_fn
(
train_config
,
train_input_config
,
model_config
):
model_config
):
"""Creates a train `input` function for `Estimator`.
"""Creates a train `input` function for `Estimator`.
...
@@ -249,38 +295,8 @@ def create_train_input_fn(train_config, train_input_config,
...
@@ -249,38 +295,8 @@ def create_train_input_fn(train_config, train_input_config,
num_classes
=
config_util
.
get_number_of_classes
(
model_config
),
num_classes
=
config_util
.
get_number_of_classes
(
model_config
),
spatial_image_shape
=
config_util
.
get_spatial_image_size
(
spatial_image_shape
=
config_util
.
get_spatial_image_size
(
image_resizer_config
))
image_resizer_config
))
tensor_dict
=
dataset_util
.
make_initializable_iterator
(
dataset
).
get_next
()
input_dict
=
dataset_util
.
make_initializable_iterator
(
dataset
).
get_next
()
return
(
_get_features_dict
(
input_dict
),
_get_labels_dict
(
input_dict
))
hash_from_source_id
=
tf
.
string_to_hash_bucket_fast
(
tensor_dict
[
fields
.
InputDataFields
.
source_id
],
HASH_BINS
)
features
=
{
fields
.
InputDataFields
.
image
:
tensor_dict
[
fields
.
InputDataFields
.
image
],
HASH_KEY
:
tf
.
cast
(
hash_from_source_id
,
tf
.
int32
),
fields
.
InputDataFields
.
true_image_shape
:
tensor_dict
[
fields
.
InputDataFields
.
true_image_shape
]
}
if
fields
.
InputDataFields
.
original_image
in
tensor_dict
:
features
[
fields
.
InputDataFields
.
original_image
]
=
tensor_dict
[
fields
.
InputDataFields
.
original_image
]
labels
=
{
fields
.
InputDataFields
.
num_groundtruth_boxes
:
tensor_dict
[
fields
.
InputDataFields
.
num_groundtruth_boxes
],
fields
.
InputDataFields
.
groundtruth_boxes
:
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_boxes
],
fields
.
InputDataFields
.
groundtruth_classes
:
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_classes
],
fields
.
InputDataFields
.
groundtruth_weights
:
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_weights
]
}
if
fields
.
InputDataFields
.
groundtruth_keypoints
in
tensor_dict
:
labels
[
fields
.
InputDataFields
.
groundtruth_keypoints
]
=
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_keypoints
]
if
fields
.
InputDataFields
.
groundtruth_instance_masks
in
tensor_dict
:
labels
[
fields
.
InputDataFields
.
groundtruth_instance_masks
]
=
tensor_dict
[
fields
.
InputDataFields
.
groundtruth_instance_masks
]
return
features
,
labels
return
_train_input_fn
return
_train_input_fn
...
@@ -365,36 +381,7 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
...
@@ -365,36 +381,7 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
image_resizer_config
))
image_resizer_config
))
input_dict
=
dataset_util
.
make_initializable_iterator
(
dataset
).
get_next
()
input_dict
=
dataset_util
.
make_initializable_iterator
(
dataset
).
get_next
()
hash_from_source_id
=
tf
.
string_to_hash_bucket_fast
(
return
(
_get_features_dict
(
input_dict
),
_get_labels_dict
(
input_dict
))
input_dict
[
fields
.
InputDataFields
.
source_id
],
HASH_BINS
)
features
=
{
fields
.
InputDataFields
.
image
:
input_dict
[
fields
.
InputDataFields
.
image
],
fields
.
InputDataFields
.
original_image
:
input_dict
[
fields
.
InputDataFields
.
original_image
],
HASH_KEY
:
tf
.
cast
(
hash_from_source_id
,
tf
.
int32
),
fields
.
InputDataFields
.
true_image_shape
:
input_dict
[
fields
.
InputDataFields
.
true_image_shape
]
}
labels
=
{
fields
.
InputDataFields
.
groundtruth_boxes
:
input_dict
[
fields
.
InputDataFields
.
groundtruth_boxes
],
fields
.
InputDataFields
.
groundtruth_classes
:
input_dict
[
fields
.
InputDataFields
.
groundtruth_classes
],
fields
.
InputDataFields
.
groundtruth_area
:
input_dict
[
fields
.
InputDataFields
.
groundtruth_area
],
fields
.
InputDataFields
.
groundtruth_is_crowd
:
input_dict
[
fields
.
InputDataFields
.
groundtruth_is_crowd
],
fields
.
InputDataFields
.
groundtruth_difficult
:
tf
.
cast
(
input_dict
[
fields
.
InputDataFields
.
groundtruth_difficult
],
tf
.
int32
)
}
if
fields
.
InputDataFields
.
groundtruth_instance_masks
in
input_dict
:
labels
[
fields
.
InputDataFields
.
groundtruth_instance_masks
]
=
input_dict
[
fields
.
InputDataFields
.
groundtruth_instance_masks
]
return
features
,
labels
return
_eval_input_fn
return
_eval_input_fn
...
...
research/object_detection/model.py
View file @
bb43ed96
...
@@ -241,7 +241,9 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
...
@@ -241,7 +241,9 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
groundtruth_boxes_list
=
gt_boxes_list
,
groundtruth_boxes_list
=
gt_boxes_list
,
groundtruth_classes_list
=
gt_classes_list
,
groundtruth_classes_list
=
gt_classes_list
,
groundtruth_masks_list
=
gt_masks_list
,
groundtruth_masks_list
=
gt_masks_list
,
groundtruth_keypoints_list
=
gt_keypoints_list
)
groundtruth_keypoints_list
=
gt_keypoints_list
,
groundtruth_weights_list
=
labels
[
fields
.
InputDataFields
.
groundtruth_weights
])
preprocessed_images
=
features
[
fields
.
InputDataFields
.
image
]
preprocessed_images
=
features
[
fields
.
InputDataFields
.
image
]
prediction_dict
=
detection_model
.
predict
(
prediction_dict
=
detection_model
.
predict
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment