Commit 481cf8da authored by Vivek Rathod, committed by TF Object Detection Team
Browse files

Output multiclass scores from post-process.

PiperOrigin-RevId: 336090589
parent 92752da2
...@@ -2856,6 +2856,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2856,6 +2856,8 @@ class CenterNetMetaArch(model.DetectionModel):
feature extractor's final layer output. feature extractor's final layer output.
detection_scores: A tensor of shape [batch, max_detections] holding detection_scores: A tensor of shape [batch, max_detections] holding
the predicted score for each box. the predicted score for each box.
detection_multiclass_scores: A tensor of shape [batch, max_detections,
num_classes] holding the multiclass scores for each box.
detection_classes: An integer tensor of shape [batch, max_detections] detection_classes: An integer tensor of shape [batch, max_detections]
containing the detected class for each box. containing the detected class for each box.
num_detections: An integer tensor of shape [batch] containing the num_detections: An integer tensor of shape [batch] containing the
...@@ -2883,7 +2885,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2883,7 +2885,8 @@ class CenterNetMetaArch(model.DetectionModel):
top_k_feature_map_locations( top_k_feature_map_locations(
object_center_prob, max_pool_kernel_size=3, object_center_prob, max_pool_kernel_size=3,
k=self._center_params.max_box_predictions)) k=self._center_params.max_box_predictions))
multiclass_scores = tf.gather_nd(
object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1)
boxes_strided, classes, scores, num_detections = ( boxes_strided, classes, scores, num_detections = (
prediction_tensors_to_boxes( prediction_tensors_to_boxes(
detection_scores, y_indices, x_indices, channel_indices, detection_scores, y_indices, x_indices, channel_indices,
...@@ -2895,6 +2898,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2895,6 +2898,8 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict = { postprocess_dict = {
fields.DetectionResultFields.detection_boxes: boxes, fields.DetectionResultFields.detection_boxes: boxes,
fields.DetectionResultFields.detection_scores: scores, fields.DetectionResultFields.detection_scores: scores,
fields.DetectionResultFields.detection_multiclass_scores:
multiclass_scores,
fields.DetectionResultFields.detection_classes: classes, fields.DetectionResultFields.detection_classes: classes,
fields.DetectionResultFields.num_detections: num_detections, fields.DetectionResultFields.num_detections: num_detections,
'detection_boxes_strided': boxes_strided 'detection_boxes_strided': boxes_strided
......
...@@ -1507,7 +1507,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1507,7 +1507,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32) keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2) keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
class_probs = np.zeros(10) class_probs = np.ones(10) * _logit(0.25)
class_probs[target_class_id] = _logit(0.75) class_probs[target_class_id] = _logit(0.75)
class_center[0, 16, 16] = class_probs class_center[0, 16, 16] = class_probs
height_width[0, 16, 16] = [5, 10] height_width[0, 16, 16] = [5, 10]
...@@ -1582,6 +1582,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1582,6 +1582,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
np.array([55, 46, 75, 86]) / 128.0) np.array([55, 46, 75, 86]) / 128.0)
self.assertAllClose(detections['detection_scores'][0], self.assertAllClose(detections['detection_scores'][0],
[.75, .5, .5, .5, .5]) [.75, .5, .5, .5, .5])
expected_multiclass_scores = [.25] * 10
expected_multiclass_scores[target_class_id] = .75
self.assertAllClose(expected_multiclass_scores,
detections['detection_multiclass_scores'][0][0])
# The output embedding extracted at the object center will be a 3-D array of # The output embedding extracted at the object center will be a 3-D array of
# shape [batch, num_boxes, embedding_size]. The valid predicted embedding # shape [batch, num_boxes, embedding_size]. The valid predicted embedding
# will be the first embedding in the first batch. It is a 1-D array of # will be the first embedding in the first batch. It is a 1-D array of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment