Commit 6559ed14 authored by Vivek Rathod's avatar Vivek Rathod Committed by TF Object Detection Team
Browse files

Add optional Non-Max Suppression in CenterNet

PiperOrigin-RevId: 354590046
parent 3063aeb3
......@@ -1039,7 +1039,10 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
if center_net_config.HasField('temporal_offset_task'):
temporal_offset_params = temporal_offset_proto_to_params(
center_net_config.temporal_offset_task)
non_max_suppression_fn = None
if center_net_config.HasField('post_processing'):
non_max_suppression_fn, _ = post_processing_builder.build(
center_net_config.post_processing)
return center_net_meta_arch.CenterNetMetaArch(
is_training=is_training,
add_summaries=add_summaries,
......@@ -1054,7 +1057,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
track_params=track_params,
temporal_offset_params=temporal_offset_params,
use_depthwise=center_net_config.use_depthwise,
compute_heatmap_sparse=center_net_config.compute_heatmap_sparse)
compute_heatmap_sparse=center_net_config.compute_heatmap_sparse,
non_max_suppression_fn=non_max_suppression_fn)
def _build_center_net_feature_extractor(
......
......@@ -1896,7 +1896,8 @@ class CenterNetMetaArch(model.DetectionModel):
track_params=None,
temporal_offset_params=None,
use_depthwise=False,
compute_heatmap_sparse=False):
compute_heatmap_sparse=False,
non_max_suppression_fn=None):
"""Initializes a CenterNet model.
Args:
......@@ -1939,6 +1940,7 @@ class CenterNetMetaArch(model.DetectionModel):
the Op that computes the center heatmaps. The sparse version scales
better with number of channels in the heatmap, but in some cases is
known to cause an OOM error. See b/170989061.
non_max_suppression_fn: Optional Non Max Suppression function to apply.
"""
assert object_detection_params or keypoint_params_dict
# Shorten the name for convenience and better formatting.
......@@ -1977,6 +1979,7 @@ class CenterNetMetaArch(model.DetectionModel):
# Will be used in VOD single_frame_meta_arch for tensor reshape.
self._batched_prediction_tensor_names = []
self._non_max_suppression_fn = non_max_suppression_fn
super(CenterNetMetaArch, self).__init__(num_classes)
......@@ -3108,6 +3111,34 @@ class CenterNetMetaArch(model.DetectionModel):
prediction_dict[TEMPORAL_OFFSET][-1])
postprocess_dict[fields.DetectionResultFields.detection_offsets] = offsets
if self._non_max_suppression_fn:
boxes = tf.expand_dims(
postprocess_dict.pop(fields.DetectionResultFields.detection_boxes),
axis=-2)
multiclass_scores = postprocess_dict[
fields.DetectionResultFields.detection_multiclass_scores]
num_valid_boxes = postprocess_dict.pop(
fields.DetectionResultFields.num_detections)
# Remove scores and classes as NMS will compute these from multiclass
# scores.
postprocess_dict.pop(fields.DetectionResultFields.detection_scores)
postprocess_dict.pop(fields.DetectionResultFields.detection_classes)
(nmsed_boxes, nmsed_scores, nmsed_classes, _, nmsed_additional_fields,
num_detections) = self._non_max_suppression_fn(
boxes,
multiclass_scores,
additional_fields=postprocess_dict,
num_valid_boxes=num_valid_boxes)
postprocess_dict = nmsed_additional_fields
postprocess_dict[
fields.DetectionResultFields.detection_boxes] = nmsed_boxes
postprocess_dict[
fields.DetectionResultFields.detection_scores] = nmsed_scores
postprocess_dict[
fields.DetectionResultFields.detection_classes] = nmsed_classes
postprocess_dict[
fields.DetectionResultFields.num_detections] = num_detections
postprocess_dict.update(nmsed_additional_fields)
return postprocess_dict
def postprocess_single_instance_keypoints(self, prediction_dict,
......
......@@ -24,12 +24,14 @@ from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
from object_detection.builders import post_processing_builder
from object_detection.core import losses
from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner as cn_assigner
from object_detection.meta_architectures import center_net_meta_arch as cnma
from object_detection.models import center_net_resnet_feature_extractor
from object_detection.protos import post_processing_pb2
from object_detection.utils import test_case
from object_detection.utils import tf_version
......@@ -1349,7 +1351,9 @@ def get_fake_temporal_offset_params():
def build_center_net_meta_arch(build_resnet=False,
num_classes=_NUM_CLASSES,
max_box_predictions=5):
max_box_predictions=5,
apply_non_max_suppression=False,
detection_only=False):
"""Builds the CenterNet meta architecture."""
if build_resnet:
feature_extractor = (
......@@ -1368,7 +1372,31 @@ def build_center_net_meta_arch(build_resnet=False,
max_dimension=128,
pad_to_max_dimesnion=True)
if num_classes == 1:
non_max_suppression_fn = None
if apply_non_max_suppression:
post_processing_proto = post_processing_pb2.PostProcessing()
post_processing_proto.batch_non_max_suppression.iou_threshold = 1.0
post_processing_proto.batch_non_max_suppression.score_threshold = 0.6
(post_processing_proto.batch_non_max_suppression.max_total_detections
) = max_box_predictions
(post_processing_proto.batch_non_max_suppression.max_detections_per_class
) = max_box_predictions
(post_processing_proto.batch_non_max_suppression.change_coordinate_frame
) = False
non_max_suppression_fn, _ = post_processing_builder.build(
post_processing_proto)
if detection_only:
return cnma.CenterNetMetaArch(
is_training=True,
add_summaries=False,
num_classes=num_classes,
feature_extractor=feature_extractor,
image_resizer_fn=image_resizer_fn,
object_center_params=get_fake_center_params(max_box_predictions),
object_detection_params=get_fake_od_params(),
non_max_suppression_fn=non_max_suppression_fn)
elif num_classes == 1:
num_candidates_per_keypoint = 100 if max_box_predictions > 1 else 1
return cnma.CenterNetMetaArch(
is_training=True,
......@@ -1380,7 +1408,8 @@ def build_center_net_meta_arch(build_resnet=False,
object_detection_params=get_fake_od_params(),
keypoint_params_dict={
_TASK_NAME: get_fake_kp_params(num_candidates_per_keypoint)
})
},
non_max_suppression_fn=non_max_suppression_fn)
else:
return cnma.CenterNetMetaArch(
is_training=True,
......@@ -1394,7 +1423,8 @@ def build_center_net_meta_arch(build_resnet=False,
mask_params=get_fake_mask_params(),
densepose_params=get_fake_densepose_params(),
track_params=get_fake_track_params(),
temporal_offset_params=get_fake_temporal_offset_params())
temporal_offset_params=get_fake_temporal_offset_params(),
non_max_suppression_fn=non_max_suppression_fn)
def _logit(p):
......@@ -1728,7 +1758,6 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
return detections
detections = self.execute_cpu(graph_fn, [])
self.assertAllClose(detections['detection_boxes'][0, 0],
np.array([55, 46, 75, 86]) / 128.0)
self.assertAllClose(detections['detection_scores'][0],
......@@ -1801,6 +1830,49 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
detections['detection_surface_coords'][0, 0, :, :],
np.zeros_like(detections['detection_surface_coords'][0, 0, :, :]))
def test_non_max_suppression(self):
  """Tests application of NMS on CenterNet detections.

  Builds a detection-only model with non-max suppression enabled, feeds it
  a prediction with a single active heatmap cell at (16, 16) whose logits
  are `_logit(0.75)` for `target_class_id` and `_logit(0.25)` for every
  other class, and checks that postprocessing returns exactly one
  detection with the expected box, score and multiclass scores.
  """
  target_class_id = 1
  model = build_center_net_meta_arch(apply_non_max_suppression=True,
                                     detection_only=True)

  # Single-image batch of 32x32 prediction maps with 10 class channels.
  class_center = np.zeros((1, 32, 32, 10), dtype=np.float32)
  height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
  offset = np.zeros((1, 32, 32, 2), dtype=np.float32)

  # Only cell (16, 16) carries non-zero predictions.
  class_probs = np.ones(10) * _logit(0.25)
  class_probs[target_class_id] = _logit(0.75)
  class_center[0, 16, 16] = class_probs
  height_width[0, 16, 16] = [5, 10]
  offset[0, 16, 16] = [.25, .5]

  class_center = tf.constant(class_center)
  height_width = tf.constant(height_width)
  offset = tf.constant(offset)

  prediction_dict = {
      cnma.OBJECT_CENTER: [class_center],
      cnma.BOX_SCALE: [height_width],
      cnma.BOX_OFFSET: [offset],
  }

  def graph_fn():
    detections = model.postprocess(prediction_dict,
                                   tf.constant([[128, 128, 3]]))
    return detections

  detections = self.execute_cpu(graph_fn, [])

  # Index the single batch element explicitly rather than calling int() on
  # the whole array: implicit conversion of a size-1, ndim > 0 array to a
  # scalar is deprecated in NumPy >= 1.25.
  # NOTE(review): assumes `num_detections` has shape [batch_size] -- confirm.
  num_detections = int(detections['num_detections'][0])
  self.assertEqual(num_detections, 1)
  # Expected box is expressed relative to the 128x128 input image.
  self.assertAllClose(detections['detection_boxes'][0, 0],
                      np.array([55, 46, 75, 86]) / 128.0)
  self.assertAllClose(detections['detection_scores'][0][:num_detections],
                      [.75])
  expected_multiclass_scores = [.25] * 10
  expected_multiclass_scores[target_class_id] = .75
  self.assertAllClose(expected_multiclass_scores,
                      detections['detection_multiclass_scores'][0][0])
def test_postprocess_single_class(self):
"""Test the postprocess function."""
model = build_center_net_meta_arch(num_classes=1)
......
......@@ -4,6 +4,7 @@ package object_detection.protos;
import "object_detection/protos/image_resizer.proto";
import "object_detection/protos/losses.proto";
import "object_detection/protos/post_processing.proto";
// Configuration for the CenterNet meta architecture from the "Objects as
// Points" paper [1]
......@@ -271,6 +272,13 @@ message CenterNet {
optional TemporalOffsetEstimation temporal_offset_task = 12;
// CenterNet does not apply conventional post processing operations such as
// non max suppression as it applies a max-pool operator on box centers.
// However, in some cases we observe the need to remove duplicate predictions
// from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding.
optional PostProcessing post_processing = 24;
}
message CenterNetFeatureExtractor {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment