Commit 481cf8da authored by Vivek Rathod's avatar Vivek Rathod Committed by TF Object Detection Team
Browse files

Output multiclass scores from post-process.

PiperOrigin-RevId: 336090589
parent 92752da2
......@@ -2856,6 +2856,8 @@ class CenterNetMetaArch(model.DetectionModel):
feature extractor's final layer output.
detection_scores: A tensor of shape [batch, max_detections] holding
the predicted score for each box.
detection_multiclass_scores: A tensor of shape [batch, max_detection,
num_classes] holding multiclass score for each box.
detection_classes: An integer tensor of shape [batch, max_detections]
containing the detected class for each box.
num_detections: An integer tensor of shape [batch] containing the
......@@ -2883,7 +2885,8 @@ class CenterNetMetaArch(model.DetectionModel):
top_k_feature_map_locations(
object_center_prob, max_pool_kernel_size=3,
k=self._center_params.max_box_predictions))
multiclass_scores = tf.gather_nd(
object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1)
boxes_strided, classes, scores, num_detections = (
prediction_tensors_to_boxes(
detection_scores, y_indices, x_indices, channel_indices,
......@@ -2895,6 +2898,8 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict = {
fields.DetectionResultFields.detection_boxes: boxes,
fields.DetectionResultFields.detection_scores: scores,
fields.DetectionResultFields.detection_multiclass_scores:
multiclass_scores,
fields.DetectionResultFields.detection_classes: classes,
fields.DetectionResultFields.num_detections: num_detections,
'detection_boxes_strided': boxes_strided
......
......@@ -1507,7 +1507,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
class_probs = np.zeros(10)
class_probs = np.ones(10) * _logit(0.25)
class_probs[target_class_id] = _logit(0.75)
class_center[0, 16, 16] = class_probs
height_width[0, 16, 16] = [5, 10]
......@@ -1582,6 +1582,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
np.array([55, 46, 75, 86]) / 128.0)
self.assertAllClose(detections['detection_scores'][0],
[.75, .5, .5, .5, .5])
expected_multiclass_scores = [.25] * 10
expected_multiclass_scores[target_class_id] = .75
self.assertAllClose(expected_multiclass_scores,
detections['detection_multiclass_scores'][0][0])
# The output embedding extracted at the object center will be a 3-D array of
# shape [batch, num_boxes, embedding_size]. The valid predicted embedding
# will be the first embedding in the first batch. It is a 1-D array of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment