ModelZoo / ResNet50_tensorflow

Commit 776cb1ca
Authored Feb 24, 2021 by Ronny Votel; committed by TF Object Detection Team, Feb 24, 2021
Parent: 4864e795

Updating hyperparameters for CenterNet that together make a sizable impact on keypoint performance.

PiperOrigin-RevId: 359406239
Showing 6 changed files with 355 additions and 11 deletions.
research/object_detection/builders/model_builder.py                             +4   -1
research/object_detection/meta_architectures/center_net_meta_arch.py            +99  -9
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py   +61  -0
research/object_detection/protos/center_net.proto                               +19  -0
research/object_detection/utils/config_util.py                                  +111 -1
research/object_detection/utils/config_util_test.py                             +61  -0

research/object_detection/builders/model_builder.py

@@ -871,7 +871,10 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
       per_keypoint_offset=kp_config.per_keypoint_offset,
       predict_depth=kp_config.predict_depth,
       per_keypoint_depth=kp_config.per_keypoint_depth,
-      keypoint_depth_loss_weight=kp_config.keypoint_depth_loss_weight)
+      keypoint_depth_loss_weight=kp_config.keypoint_depth_loss_weight,
+      score_distance_offset=kp_config.score_distance_offset,
+      clip_out_of_frame_keypoints=kp_config.clip_out_of_frame_keypoints,
+      rescore_instances=kp_config.rescore_instances)


 def object_detection_proto_to_params(od_config):

research/object_detection/meta_architectures/center_net_meta_arch.py

@@ -744,6 +744,7 @@ def refine_keypoints(regressed_keypoints,
                      box_scale=1.2,
                      candidate_search_scale=0.3,
                      candidate_ranking_mode='min_distance',
+                     score_distance_offset=1e-6,
                      keypoint_depth_candidates=None):
   """Refines regressed keypoints by snapping to the nearest candidate keypoints.

@@ -800,6 +801,11 @@ def refine_keypoints(regressed_keypoints,
     candidate_ranking_mode: A string as one of ['min_distance',
       'score_distance_ratio'] indicating how to select the candidate. If invalid
       value is provided, an ValueError will be raised.
+    score_distance_offset: The distance offset to apply in the denominator when
+      candidate_ranking_mode is 'score_distance_ratio'. The metric to maximize
+      in this scenario is score / (distance + score_distance_offset). Larger
+      values of score_distance_offset make the keypoint score gain more relative
+      importance.
     keypoint_depth_candidates: (optional) A float tensor of shape
       [batch_size, max_candidates, num_keypoints] indicating the depths for
       keypoint candidates.

@@ -873,7 +879,7 @@ def refine_keypoints(regressed_keypoints,
     tiled_keypoint_scores = tf.tile(
         tf.expand_dims(keypoint_scores, axis=1),
         multiples=[1, num_instances, 1, 1])
-    ranking_scores = tiled_keypoint_scores / (distances + 1e-6)
+    ranking_scores = tiled_keypoint_scores / (distances + score_distance_offset)
     nearby_candidate_inds = tf.math.argmax(ranking_scores, axis=2)
   else:
     raise ValueError('Not recognized candidate_ranking_mode: %s' %
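
The new 'score_distance_ratio' path is easiest to see in isolation. Below is a minimal NumPy sketch, separate from the library code above, using made-up candidate scores and distances for a single regressed keypoint:

import numpy as np

# Hypothetical candidate scores and distances (in output-stride pixels) for one
# regressed keypoint. The numbers are illustrative only.
candidate_scores = np.array([0.9, 0.6, 0.3])
candidate_distances = np.array([5.0, 1.0, 0.8])
score_distance_offset = 1e-6

# 'min_distance' snaps to the closest candidate regardless of its score.
min_distance_choice = np.argmin(candidate_distances)  # index 2

# 'score_distance_ratio' maximizes score / (distance + score_distance_offset),
# so a slightly farther but much more confident candidate can win. A larger
# offset damps the distance term and gives the score more relative weight.
ranking_scores = candidate_scores / (candidate_distances + score_distance_offset)
score_distance_choice = np.argmax(ranking_scores)  # index 1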

@@ -1590,7 +1596,9 @@ class KeypointEstimationParams(
         'peak_max_pool_kernel_size', 'unmatched_keypoint_score', 'box_scale',
         'candidate_search_scale', 'candidate_ranking_mode',
         'offset_peak_radius', 'per_keypoint_offset', 'predict_depth',
-        'per_keypoint_depth', 'keypoint_depth_loss_weight'
+        'per_keypoint_depth', 'keypoint_depth_loss_weight',
+        'score_distance_offset', 'clip_out_of_frame_keypoints',
+        'rescore_instances'
     ])):
   """Namedtuple to host object detection related parameters.

@@ -1626,7 +1634,10 @@ class KeypointEstimationParams(
               per_keypoint_offset=False,
               predict_depth=False,
               per_keypoint_depth=False,
-              keypoint_depth_loss_weight=1.0):
+              keypoint_depth_loss_weight=1.0,
+              score_distance_offset=1e-6,
+              clip_out_of_frame_keypoints=False,
+              rescore_instances=False):
    """Constructor with default values for KeypointEstimationParams.

    Args:

@@ -1696,6 +1707,16 @@ class KeypointEstimationParams(
        of each keypoints in independent channels. Similar to
        per_keypoint_offset but for the keypoint depth.
      keypoint_depth_loss_weight: The weight of the keypoint depth loss.
+      score_distance_offset: The distance offset to apply in the denominator
+        when candidate_ranking_mode is 'score_distance_ratio'. The metric to
+        maximize in this scenario is score / (distance + score_distance_offset).
+        Larger values of score_distance_offset make the keypoint score gain more
+        relative importance.
+      clip_out_of_frame_keypoints: Whether keypoints outside the image frame
+        should be clipped back to the image boundary. If True, the keypoints
+        that are clipped have scores set to 0.0.
+      rescore_instances: Whether to rescore instances based on a combination of
+        detection score and keypoint scores.

    Returns:
      An initialized KeypointEstimationParams namedtuple.

@@ -1709,7 +1730,8 @@ class KeypointEstimationParams(
        peak_max_pool_kernel_size, unmatched_keypoint_score, box_scale,
        candidate_search_scale, candidate_ranking_mode, offset_peak_radius,
        per_keypoint_offset, predict_depth, per_keypoint_depth,
-        keypoint_depth_loss_weight)
+        keypoint_depth_loss_weight, score_distance_offset,
+        clip_out_of_frame_keypoints, rescore_instances)


 class ObjectCenterParams(

@@ -2949,6 +2971,71 @@ class CenterNetMetaArch(model.DetectionModel):
       loss_dict[TEMPORAL_OFFSET] = offset_loss

     return loss_dict

+  def _should_clip_keypoints(self):
+    """Returns a boolean indicating whether keypoint clipping should occur.
+
+    If there is only one keypoint task, clipping is controlled by the field
+    `clip_out_of_frame_keypoints`. If there are multiple keypoint tasks,
+    clipping logic is defined based on unanimous agreement of keypoint
+    parameters. If there is any ambiguity, clip_out_of_frame_keypoints is set
+    to False (default).
+    """
+    kp_params_iterator = iter(self._kp_params_dict.values())
+    if len(self._kp_params_dict) == 1:
+      kp_params = next(kp_params_iterator)
+      return kp_params.clip_out_of_frame_keypoints
+
+    # Multi-task setting.
+    kp_params = next(kp_params_iterator)
+    should_clip = kp_params.clip_out_of_frame_keypoints
+    for kp_params in kp_params_iterator:
+      if kp_params.clip_out_of_frame_keypoints != should_clip:
+        return False
+    return should_clip
+
+  def _rescore_instances(self, classes, scores, keypoint_scores):
+    """Rescores instances based on detection and keypoint scores.
+
+    Args:
+      classes: A [batch, max_detections] int32 tensor with detection classes.
+      scores: A [batch, max_detections] float32 tensor with detection scores.
+      keypoint_scores: A [batch, max_detections, total_num_keypoints] float32
+        tensor with keypoint scores.
+
+    Returns:
+      A [batch, max_detections] float32 tensor with possibly altered detection
+      scores.
+    """
+    batch, max_detections, total_num_keypoints = (
+        shape_utils.combined_static_and_dynamic_shape(keypoint_scores))
+    classes_tiled = tf.tile(classes[:, :, tf.newaxis],
+                            multiples=[1, 1, total_num_keypoints])
+    # TODO(yuhuic): Investigate whether this function will create subgraphs in
+    # tflite that will cause the model to run slower at inference.
+    for kp_params in self._kp_params_dict.values():
+      if not kp_params.rescore_instances:
+        continue
+      class_id = kp_params.class_id
+      keypoint_indices = kp_params.keypoint_indices
+      num_keypoints = len(keypoint_indices)
+      kpt_mask = tf.reduce_sum(
+          tf.one_hot(keypoint_indices, depth=total_num_keypoints), axis=0)
+      kpt_mask_tiled = tf.tile(kpt_mask[tf.newaxis, tf.newaxis, :],
+                               multiples=[batch, max_detections, 1])
+      class_and_keypoint_mask = tf.math.logical_and(
+          classes_tiled == class_id,
+          kpt_mask_tiled == 1.0)
+      class_and_keypoint_mask_float = tf.cast(class_and_keypoint_mask,
+                                              dtype=tf.float32)
+      scores_for_class = (1./num_keypoints) * (
+          tf.reduce_sum(class_and_keypoint_mask_float *
+                        scores[:, :, tf.newaxis] *
+                        keypoint_scores, axis=-1))
+      scores = tf.where(classes == class_id,
+                        scores_for_class,
+                        scores)
+    return scores
+
   def preprocess(self, inputs):
     outputs = shape_utils.resize_images_and_return_shapes(
         inputs, self._image_resizer_fn)
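
For a concrete sense of what _rescore_instances computes, the arithmetic below mirrors the toy values used in the new test_rescore_instances test later in this commit (the rescored class has three keypoints):

# Plain-Python restatement of the rescoring rule; values copied from the test.
o = 0.75                           # detection score of the rescored instance
keypoint_scores = [0.1, 0.2, 0.3]  # scores of that class's three keypoints
new_score = o * sum(keypoint_scores) / len(keypoint_scores)  # 0.75 * 0.2 = 0.15
# The other instance belongs to a class whose task leaves rescore_instances
# False, so its detection score (0.5) passes through unchanged.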

@@ -3214,18 +3301,16 @@ class CenterNetMetaArch(model.DetectionModel):
     # If the model is trained to predict only one class of object and its
     # keypoint, we fall back to a simpler postprocessing function which uses
     # the ops that are supported by tf.lite on GPU.
+    clip_keypoints = self._should_clip_keypoints()
     if len(self._kp_params_dict) == 1 and self._num_classes == 1:
       (keypoints, keypoint_scores,
        keypoint_depths) = self._postprocess_keypoints_single_class(
            prediction_dict, classes, y_indices, x_indices, boxes_strided,
            num_detections)
-      # The map_fn used to clip out of frame keypoints creates issues when
-      # converting to tf.lite model so we disable it and let the users to
-      # handle those out of frame keypoints.
       keypoints, keypoint_scores = (
           convert_strided_predictions_to_normalized_keypoints(
               keypoints, keypoint_scores, self._stride, true_image_shapes,
-              clip_out_of_frame_keypoints=False))
+              clip_out_of_frame_keypoints=clip_keypoints))
       if keypoint_depths is not None:
         postprocess_dict.update({
             fields.DetectionResultFields.detection_keypoint_depths:

@@ -3244,8 +3329,12 @@ class CenterNetMetaArch(model.DetectionModel):
       keypoints, keypoint_scores = (
           convert_strided_predictions_to_normalized_keypoints(
               keypoints, keypoint_scores, self._stride, true_image_shapes,
-              clip_out_of_frame_keypoints=True))
+              clip_out_of_frame_keypoints=clip_keypoints))
+
+      # Update instance scores based on keypoints.
+      scores = self._rescore_instances(classes, scores, keypoint_scores)
       postprocess_dict.update({
+          fields.DetectionResultFields.detection_scores: scores,
           fields.DetectionResultFields.detection_keypoints: keypoints,
           fields.DetectionResultFields.detection_keypoint_scores:
               keypoint_scores

@@ -3783,6 +3872,7 @@ class CenterNetMetaArch(model.DetectionModel):
         box_scale=kp_params.box_scale,
         candidate_search_scale=kp_params.candidate_search_scale,
         candidate_ranking_mode=kp_params.candidate_ranking_mode,
+        score_distance_offset=kp_params.score_distance_offset,
         keypoint_depth_candidates=keypoint_depth_candidates)

     return refined_keypoints, refined_scores, refined_depths

research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py

@@ -2218,6 +2218,67 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
         classes, num_detections, batch_index, class_id)
     self.assertAllEqual(valid_indices.numpy(), [0, 2])

+  def test_rescore_instances(self):
+    feature_extractor = DummyFeatureExtractor(
+        channel_means=(1.0, 2.0, 3.0),
+        channel_stds=(10., 20., 30.),
+        bgr_ordering=False,
+        num_feature_outputs=2,
+        stride=4)
+    image_resizer_fn = functools.partial(
+        preprocessor.resize_to_range,
+        min_dimension=128,
+        max_dimension=128,
+        pad_to_max_dimesnion=True)
+
+    kp_params_1 = cnma.KeypointEstimationParams(
+        task_name='kpt_task_1',
+        class_id=0,
+        keypoint_indices=[0, 1, 2],
+        keypoint_std_dev=[0.00001] * 3,
+        classification_loss=losses.WeightedSigmoidClassificationLoss(),
+        localization_loss=losses.L1LocalizationLoss(),
+        keypoint_candidate_score_threshold=0.1,
+        rescore_instances=True)  # Note rescoring for class_id = 0.
+    kp_params_2 = cnma.KeypointEstimationParams(
+        task_name='kpt_task_2',
+        class_id=1,
+        keypoint_indices=[3, 4],
+        keypoint_std_dev=[0.00001] * 2,
+        classification_loss=losses.WeightedSigmoidClassificationLoss(),
+        localization_loss=losses.L1LocalizationLoss(),
+        keypoint_candidate_score_threshold=0.1,
+        rescore_instances=False)
+    model = cnma.CenterNetMetaArch(
+        is_training=True,
+        add_summaries=False,
+        num_classes=2,
+        feature_extractor=feature_extractor,
+        image_resizer_fn=image_resizer_fn,
+        object_center_params=get_fake_center_params(),
+        object_detection_params=get_fake_od_params(),
+        keypoint_params_dict={
+            'kpt_task_1': kp_params_1,
+            'kpt_task_2': kp_params_2,
+        })
+
+    def graph_fn():
+      classes = tf.constant([[1, 0]], dtype=tf.int32)
+      scores = tf.constant([[0.5, 0.75]], dtype=tf.float32)
+      keypoint_scores = tf.constant(
+          [
+              [[0.1, 0.2, 0.3, 0.4, 0.5],
+               [0.1, 0.2, 0.3, 0.4, 0.5]],
+          ])
+      new_scores = model._rescore_instances(classes, scores, keypoint_scores)
+      return new_scores
+
+    new_scores = self.execute_cpu(graph_fn, [])
+    expected_scores = np.array([[0.5, 0.75 * (0.1 + 0.2 + 0.3) / 3]])
+    self.assertAllClose(expected_scores, new_scores)
+

 def get_fake_prediction_dict(input_height,
                              input_width,

research/object_detection/protos/center_net.proto

@@ -153,6 +153,12 @@ message CenterNet {
     // the keypoint candidate.
     optional string candidate_ranking_mode = 16 [default = "min_distance"];

+    // The score distance ratio offset, only used if candidate_ranking_mode is
+    // 'score_distance_ratio'. The offset is used in the maximization of score
+    // distance ratio, defined as:
+    //   keypoint_score / (distance + score_distance_offset)
+    optional float score_distance_offset = 22 [default = 1.0];
+
     // The radius (in the unit of output pixel) around heatmap peak to assign
     // the offset targets. If set 0, then the offset target will only be
     // assigned to the heatmap peak (same behavior as the original paper).

@@ -180,6 +186,19 @@ message CenterNet {
     // The weight of the keypoint depth loss.
     optional float keypoint_depth_loss_weight = 21 [default = 1.0];
+
+    // Whether keypoints outside the image frame should be clipped back to the
+    // image boundary. If true, the keypoints that are clipped have scores set
+    // to 0.0.
+    optional bool clip_out_of_frame_keypoints = 23 [default = false];
+
+    // Whether instances should be rescored based on keypoint confidences. If
+    // False, will use the detection score (from the object center heatmap). If
+    // True, will compute new scores with:
+    //   new_score = o * (1/k) sum {s_i}
+    // where o is the object score, s_i is the score for keypoint i, and k is
+    // the number of keypoints for that class.
+    optional bool rescore_instances = 24 [default = false];
   }

   repeated KeypointEstimation keypoint_estimation_task = 7;
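
The three new fields can also be set programmatically, in the same way the new config_util tests below build their pipeline configs. A short sketch, assuming a TrainEvalPipelineConfig with a single keypoint estimation task; the specific values are illustrative only:

from object_detection.protos import pipeline_pb2

pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()

# The new CenterNet keypoint fields introduced by this commit.
kpt_task.candidate_ranking_mode = "score_distance_ratio"
kpt_task.score_distance_offset = 0.1
kpt_task.clip_out_of_frame_keypoints = True
kpt_task.rescore_instances = True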

research/object_detection/utils/config_util.py

@@ -19,9 +19,9 @@ from __future__ import division
 from __future__ import print_function

 import os
-import tensorflow.compat.v1 as tf

 from google.protobuf import text_format
+import tensorflow.compat.v1 as tf
 from tensorflow.python.lib.io import file_io

@@ -623,6 +623,20 @@ def _maybe_update_config_with_key_value(configs, key, value):
     _update_num_classes(configs["model"], value)
   elif field_name == "sample_from_datasets_weights":
     _update_sample_from_datasets_weights(configs["train_input_config"], value)
+  elif field_name == "peak_max_pool_kernel_size":
+    _update_peak_max_pool_kernel_size(configs["model"], value)
+  elif field_name == "candidate_search_scale":
+    _update_candidate_search_scale(configs["model"], value)
+  elif field_name == "candidate_ranking_mode":
+    _update_candidate_ranking_mode(configs["model"], value)
+  elif field_name == "score_distance_offset":
+    _update_score_distance_offset(configs["model"], value)
+  elif field_name == "box_scale":
+    _update_box_scale(configs["model"], value)
+  elif field_name == "keypoint_candidate_score_threshold":
+    _update_keypoint_candidate_score_threshold(configs["model"], value)
+  elif field_name == "rescore_instances":
+    _update_rescore_instances(configs["model"], value)
   else:
     return False
   return True
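
These hooks make the new CenterNet keypoint knobs overridable by key without editing the pipeline file. A minimal usage sketch (the config path is hypothetical; as the warnings in the helpers below note, the overrides only take effect when the config has exactly one keypoint_estimation_task):

from object_detection.utils import config_util

# Hypothetical pipeline config path; any CenterNet config with a single
# keypoint estimation task is handled the same way.
configs = config_util.get_configs_from_pipeline_file(
    "path/to/centernet_pipeline.config")

# Override the new hyperparameters by key, the same mechanism exercised by the
# rescore_instances tests in this commit.
config_util.merge_external_params_with_configs(
    configs,
    kwargs_dict={
        "candidate_ranking_mode": "score_distance_ratio",
        "score_distance_offset": 0.1,
        "rescore_instances": True,
    })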

@@ -1089,3 +1103,99 @@ def _update_sample_from_datasets_weights(input_reader_config, weights):
   del input_reader_config.sample_from_datasets_weights[:]
   input_reader_config.sample_from_datasets_weights.extend(weights)
+
+
+def _update_peak_max_pool_kernel_size(model_config, kernel_size):
+  """Updates the max pool kernel size (NMS) for keypoints in CenterNet."""
+  meta_architecture = model_config.WhichOneof("model")
+  if meta_architecture == "center_net":
+    if len(model_config.center_net.keypoint_estimation_task) == 1:
+      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
+      kpt_estimation_task.peak_max_pool_kernel_size = kernel_size
+    else:
+      tf.logging.warning("Ignoring config override key for "
+                         "peak_max_pool_kernel_size since there are multiple "
+                         "keypoint estimation tasks")
+
+
+def _update_candidate_search_scale(model_config, search_scale):
+  """Updates the keypoint candidate search scale in CenterNet."""
+  meta_architecture = model_config.WhichOneof("model")
+  if meta_architecture == "center_net":
+    if len(model_config.center_net.keypoint_estimation_task) == 1:
+      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
+      kpt_estimation_task.candidate_search_scale = search_scale
+    else:
+      tf.logging.warning("Ignoring config override key for "
+                         "candidate_search_scale since there are multiple "
+                         "keypoint estimation tasks")
+
+
+def _update_candidate_ranking_mode(model_config, mode):
+  """Updates how keypoints are snapped to candidates in CenterNet."""
+  if mode not in ("min_distance", "score_distance_ratio"):
+    raise ValueError("Attempting to set the keypoint candidate ranking mode "
+                     "to {}, but the only options are 'min_distance' and "
+                     "'score_distance_ratio'.".format(mode))
+  meta_architecture = model_config.WhichOneof("model")
+  if meta_architecture == "center_net":
+    if len(model_config.center_net.keypoint_estimation_task) == 1:
+      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
+      kpt_estimation_task.candidate_ranking_mode = mode
+    else:
+      tf.logging.warning("Ignoring config override key for "
+                         "candidate_ranking_mode since there are multiple "
+                         "keypoint estimation tasks")
+
+
+def _update_score_distance_offset(model_config, offset):
+  """Updates the keypoint candidate selection metric. See CenterNet proto."""
+  meta_architecture = model_config.WhichOneof("model")
+  if meta_architecture == "center_net":
+    if len(model_config.center_net.keypoint_estimation_task) == 1:
+      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
+      kpt_estimation_task.score_distance_offset = offset
+    else:
+      tf.logging.warning("Ignoring config override key for "
+                         "score_distance_offset since there are multiple "
+                         "keypoint estimation tasks")
+
+
+def _update_box_scale(model_config, box_scale):
+  """Updates the keypoint candidate search region. See CenterNet proto."""
+  meta_architecture = model_config.WhichOneof("model")
+  if meta_architecture == "center_net":
+    if len(model_config.center_net.keypoint_estimation_task) == 1:
+      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
+      kpt_estimation_task.box_scale = box_scale
+    else:
+      tf.logging.warning("Ignoring config override key for box_scale since "
+                         "there are multiple keypoint estimation tasks")
+
+
+def _update_keypoint_candidate_score_threshold(model_config, threshold):
+  """Updates the keypoint candidate score threshold. See CenterNet proto."""
+  meta_architecture = model_config.WhichOneof("model")
+  if meta_architecture == "center_net":
+    if len(model_config.center_net.keypoint_estimation_task) == 1:
+      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
+      kpt_estimation_task.keypoint_candidate_score_threshold = threshold
+    else:
+      tf.logging.warning("Ignoring config override key for "
+                         "keypoint_candidate_score_threshold since there are "
+                         "multiple keypoint estimation tasks")
+
+
+def _update_rescore_instances(model_config, should_rescore):
+  """Updates whether boxes should be rescored based on keypoint confidences."""
+  if isinstance(should_rescore, str):
+    should_rescore = True if should_rescore == "True" else False
+  meta_architecture = model_config.WhichOneof("model")
+  if meta_architecture == "center_net":
+    if len(model_config.center_net.keypoint_estimation_task) == 1:
+      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
+      kpt_estimation_task.rescore_instances = should_rescore
+    else:
+      tf.logging.warning("Ignoring config override key for "
+                         "rescore_instances since there are multiple keypoint "
+                         "estimation tasks")

research/object_detection/utils/config_util_test.py

@@ -1018,6 +1018,67 @@ class ConfigUtilTest(tf.test.TestCase):
         output_dict,
         config_util.remove_unecessary_ema(input_dict, no_ema_collection))

+  def testUpdateRescoreInstances(self):
+    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
+    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+    kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
+    kpt_task.rescore_instances = True
+    _write_config(pipeline_config, pipeline_config_path)
+
+    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    cn_config = configs["model"].center_net
+    self.assertEqual(
+        True, cn_config.keypoint_estimation_task[0].rescore_instances)
+
+    config_util.merge_external_params_with_configs(
+        configs, kwargs_dict={"rescore_instances": False})
+    cn_config = configs["model"].center_net
+    self.assertEqual(
+        False, cn_config.keypoint_estimation_task[0].rescore_instances)
+
+  def testUpdateRescoreInstancesWithBooleanString(self):
+    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
+    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+    kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
+    kpt_task.rescore_instances = True
+    _write_config(pipeline_config, pipeline_config_path)
+
+    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    cn_config = configs["model"].center_net
+    self.assertEqual(
+        True, cn_config.keypoint_estimation_task[0].rescore_instances)
+
+    config_util.merge_external_params_with_configs(
+        configs, kwargs_dict={"rescore_instances": "False"})
+    cn_config = configs["model"].center_net
+    self.assertEqual(
+        False, cn_config.keypoint_estimation_task[0].rescore_instances)
+
+  def testUpdateRescoreInstancesWithMultipleTasks(self):
+    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
+    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+    kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
+    kpt_task.rescore_instances = True
+    kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
+    kpt_task.rescore_instances = True
+    _write_config(pipeline_config, pipeline_config_path)
+
+    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
+    cn_config = configs["model"].center_net
+    self.assertEqual(
+        True, cn_config.keypoint_estimation_task[0].rescore_instances)
+
+    config_util.merge_external_params_with_configs(
+        configs, kwargs_dict={"rescore_instances": False})
+    cn_config = configs["model"].center_net
+    self.assertEqual(
+        True, cn_config.keypoint_estimation_task[0].rescore_instances)
+    self.assertEqual(
+        True, cn_config.keypoint_estimation_task[1].rescore_instances)
+

 if __name__ == "__main__":
   tf.test.main()