Commit 17169b3b authored by Yu-hui Chen, committed by TF Object Detection Team
Browse files

Updated the postprocessing logic to perform the keypoint rescoring after NMS

This ensures that NMS does not wash out the rescored instance scores.

PiperOrigin-RevId: 407602744
parent 88d844e7
...@@ -4171,11 +4171,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -4171,11 +4171,7 @@ class CenterNetMetaArch(model.DetectionModel):
keypoints, keypoint_scores, self._stride, true_image_shapes, keypoints, keypoint_scores, self._stride, true_image_shapes,
clip_out_of_frame_keypoints=clip_keypoints)) clip_out_of_frame_keypoints=clip_keypoints))
# Update instance scores based on keypoints.
scores = self._rescore_instances(
channel_indices, detection_scores, keypoint_scores)
postprocess_dict.update({ postprocess_dict.update({
fields.DetectionResultFields.detection_scores: scores,
fields.DetectionResultFields.detection_keypoints: keypoints, fields.DetectionResultFields.detection_keypoints: keypoints,
fields.DetectionResultFields.detection_keypoint_scores: fields.DetectionResultFields.detection_keypoint_scores:
keypoint_scores keypoint_scores
...@@ -4253,6 +4249,22 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -4253,6 +4249,22 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict[ postprocess_dict[
fields.DetectionResultFields.num_detections] = num_detections fields.DetectionResultFields.num_detections] = num_detections
postprocess_dict.update(nmsed_additional_fields) postprocess_dict.update(nmsed_additional_fields)
# Perform the rescoring once the NMS is applied to make sure the rescored
# scores won't be washed out by the NMS function.
if self._kp_params_dict:
channel_indices = postprocess_dict[
fields.DetectionResultFields.detection_classes]
detection_scores = postprocess_dict[
fields.DetectionResultFields.detection_scores]
keypoint_scores = postprocess_dict[
fields.DetectionResultFields.detection_keypoint_scores]
# Update instance scores based on keypoints.
scores = self._rescore_instances(
channel_indices, detection_scores, keypoint_scores)
postprocess_dict.update({
fields.DetectionResultFields.detection_scores: scores,
})
return postprocess_dict return postprocess_dict
def postprocess_single_instance_keypoints( def postprocess_single_instance_keypoints(
......
...@@ -1772,7 +1772,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100, ...@@ -1772,7 +1772,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100,
per_keypoint_depth=False, per_keypoint_depth=False,
peak_radius=0, peak_radius=0,
candidate_ranking_mode='min_distance', candidate_ranking_mode='min_distance',
argmax_postprocessing=False): argmax_postprocessing=False,
rescore_instances=False):
"""Returns the fake keypoint estimation parameter namedtuple.""" """Returns the fake keypoint estimation parameter namedtuple."""
return cnma.KeypointEstimationParams( return cnma.KeypointEstimationParams(
task_name=_TASK_NAME, task_name=_TASK_NAME,
...@@ -1789,7 +1790,9 @@ def get_fake_kp_params(num_candidates_per_keypoint=100, ...@@ -1789,7 +1790,9 @@ def get_fake_kp_params(num_candidates_per_keypoint=100,
per_keypoint_depth=per_keypoint_depth, per_keypoint_depth=per_keypoint_depth,
offset_peak_radius=peak_radius, offset_peak_radius=peak_radius,
candidate_ranking_mode=candidate_ranking_mode, candidate_ranking_mode=candidate_ranking_mode,
argmax_postprocessing=argmax_postprocessing) argmax_postprocessing=argmax_postprocessing,
rescore_instances=rescore_instances,
rescoring_threshold=0.5)
def get_fake_mask_params(): def get_fake_mask_params():
...@@ -1845,7 +1848,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1845,7 +1848,8 @@ def build_center_net_meta_arch(build_resnet=False,
peak_radius=0, peak_radius=0,
keypoint_only=False, keypoint_only=False,
candidate_ranking_mode='min_distance', candidate_ranking_mode='min_distance',
argmax_postprocessing=False): argmax_postprocessing=False,
rescore_instances=False):
"""Builds the CenterNet meta architecture.""" """Builds the CenterNet meta architecture."""
if build_resnet: if build_resnet:
feature_extractor = ( feature_extractor = (
...@@ -1867,7 +1871,7 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1867,7 +1871,7 @@ def build_center_net_meta_arch(build_resnet=False,
non_max_suppression_fn = None non_max_suppression_fn = None
if apply_non_max_suppression: if apply_non_max_suppression:
post_processing_proto = post_processing_pb2.PostProcessing() post_processing_proto = post_processing_pb2.PostProcessing()
post_processing_proto.batch_non_max_suppression.iou_threshold = 1.0 post_processing_proto.batch_non_max_suppression.iou_threshold = 0.6
post_processing_proto.batch_non_max_suppression.score_threshold = 0.6 post_processing_proto.batch_non_max_suppression.score_threshold = 0.6
(post_processing_proto.batch_non_max_suppression.max_total_detections (post_processing_proto.batch_non_max_suppression.max_total_detections
) = max_box_predictions ) = max_box_predictions
...@@ -1893,7 +1897,7 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1893,7 +1897,7 @@ def build_center_net_meta_arch(build_resnet=False,
per_keypoint_offset, predict_depth, per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius, per_keypoint_depth, peak_radius,
candidate_ranking_mode, candidate_ranking_mode,
argmax_postprocessing) argmax_postprocessing, rescore_instances)
}, },
non_max_suppression_fn=non_max_suppression_fn) non_max_suppression_fn=non_max_suppression_fn)
elif detection_only: elif detection_only:
...@@ -1922,7 +1926,7 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1922,7 +1926,7 @@ def build_center_net_meta_arch(build_resnet=False,
per_keypoint_offset, predict_depth, per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius, per_keypoint_depth, peak_radius,
candidate_ranking_mode, candidate_ranking_mode,
argmax_postprocessing) argmax_postprocessing, rescore_instances)
}, },
non_max_suppression_fn=non_max_suppression_fn) non_max_suppression_fn=non_max_suppression_fn)
else: else:
...@@ -2456,6 +2460,75 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2456,6 +2460,75 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertAllClose(expected_multiclass_scores, self.assertAllClose(expected_multiclass_scores,
detections['detection_multiclass_scores'][0][0]) detections['detection_multiclass_scores'][0][0])
def test_non_max_suppression_with_kpts_rescoring(self):
  """Checks the rescored detection score when NMS and rescoring are both on."""
  model = build_center_net_meta_arch(
      num_classes=1,
      max_box_predictions=5,
      per_keypoint_offset=True,
      candidate_ranking_mode='min_distance',
      argmax_postprocessing=False,
      apply_non_max_suppression=True,
      rescore_instances=True)
  num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)

  def _zero_map(channels):
    # All prediction maps share the same 1 x 32 x 32 spatial layout.
    return np.zeros((1, 32, 32, channels), dtype=np.float32)

  class_center = _zero_map(2)
  height_width = _zero_map(2)
  offset = _zero_map(2)
  keypoint_offsets = _zero_map(num_keypoints * 2)
  # Background keypoint heatmap value everywhere except the peaks set below.
  keypoint_heatmaps = np.ones(
      (1, 32, 32, num_keypoints), dtype=np.float32) * _logit(0.01)
  keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)

  # Two identical object centers in horizontally adjacent cells.
  center_logits = np.zeros(2)
  center_logits[1] = _logit(0.75)
  for col in (16, 17):
    class_center[0, 16, col] = center_logits
    height_width[0, 16, col] = [5, 10]
    offset[0, 16, col] = [.25, .5]

  # Only the first center gets a hand-written keypoint regression; the second
  # one keeps the random values drawn above.
  keypoint_regression[0, 16, 16] = [
      -1., -1.,
      -1., 1.,
      1., -1.,
      1., 1.]

  # (row, col, keypoint channel, heatmap probability) for each peak.
  heatmap_peaks = (
      (14, 14, 0, 0.9),
      (14, 18, 1, 0.9),
      (18, 14, 2, 0.9),
      (18, 18, 3, 0.05),  # Note the low score.
  )
  for row, col, channel, prob in heatmap_peaks:
    keypoint_heatmaps[0, row, col, channel] = _logit(prob)

  def _kp_field(field):
    return cnma.get_keypoint_name(_TASK_NAME, field)

  prediction_dict = {
      cnma.OBJECT_CENTER: [tf.constant(class_center)],
      cnma.BOX_SCALE: [tf.constant(height_width)],
      cnma.BOX_OFFSET: [tf.constant(offset)],
      _kp_field(cnma.KEYPOINT_HEATMAP):
          [tf.constant(keypoint_heatmaps, dtype=tf.float32)],
      _kp_field(cnma.KEYPOINT_OFFSET):
          [tf.constant(keypoint_offsets, dtype=tf.float32)],
      _kp_field(cnma.KEYPOINT_REGRESSION):
          [tf.constant(keypoint_regression, dtype=tf.float32)],
  }

  def graph_fn():
    return model.postprocess(prediction_dict, tf.constant([[128, 128, 3]]))

  detections = self.execute_cpu(graph_fn, [])
  num_detections = int(detections['num_detections'])

  # One of the two boxes is filtered out by NMS.
  self.assertEqual(num_detections, 1)
  # Expected value as worked out by the test author: keypoint scores of
  # [0.9, 0.9, 0.9, 0.1] give a rescored score of 0.9 * 3 / 4 = 0.675.
  self.assertAllClose(detections['detection_scores'][0][:num_detections],
                      [0.675])
@parameterized.parameters( @parameterized.parameters(
{ {
'candidate_ranking_mode': 'min_distance', 'candidate_ranking_mode': 'min_distance',
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment