Commit 0a17932d authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Removed the dependency between the keypoint task and the object detection
task in the postprocessing logic.

PiperOrigin-RevId: 362814113
parent 2e77bb3e
...@@ -125,22 +125,24 @@ def change_coordinate_frame(keypoints, window, scope=None): ...@@ -125,22 +125,24 @@ def change_coordinate_frame(keypoints, window, scope=None):
return new_keypoints return new_keypoints
def keypoints_to_enclosing_bounding_boxes(keypoints, keypoints_axis=1):
  """Creates enclosing bounding boxes from keypoints.

  Reduces over the keypoint dimension to find, per instance, the tightest
  [ymin, xmin, ymax, xmax] box that covers every keypoint.

  Args:
    keypoints: a [num_instances, num_keypoints, 2] float32 tensor with
      keypoints in [y, x] format. Higher-rank inputs (e.g. with a leading
      batch dimension) are supported via `keypoints_axis`.
    keypoints_axis: An integer indicating the axis that corresponds to the
      keypoint dimension.

  Returns:
    A [num_instances, 4] float32 tensor that tightly covers all the keypoints
    for each instance.
  """
  # The trailing axis holds [y, x]; reduce over the keypoint axis only.
  ymin = tf.math.reduce_min(keypoints[..., 0], axis=keypoints_axis)
  xmin = tf.math.reduce_min(keypoints[..., 1], axis=keypoints_axis)
  ymax = tf.math.reduce_max(keypoints[..., 0], axis=keypoints_axis)
  xmax = tf.math.reduce_max(keypoints[..., 1], axis=keypoints_axis)
  # Stack on the same axis so the box coordinates replace the reduced
  # keypoint dimension.
  return tf.stack([ymin, xmin, ymax, xmax], axis=keypoints_axis)
def to_normalized_coordinates(keypoints, height, width, def to_normalized_coordinates(keypoints, height, width,
......
...@@ -116,6 +116,35 @@ class KeypointOpsTest(test_case.TestCase): ...@@ -116,6 +116,35 @@ class KeypointOpsTest(test_case.TestCase):
]) ])
self.assertAllClose(expected_bboxes, output) self.assertAllClose(expected_bboxes, output)
def test_keypoints_to_enclosing_bounding_boxes_axis2(self):
  """Checks enclosing boxes when the keypoint dimension is axis 2."""
  def graph_fn():
    # Two instances, three keypoints each, in [y, x] format.
    instance_keypoints = tf.constant(
        [
            [  # Instance 0.
                [5., 10.],
                [3., 20.],
                [8., 4.],
            ],
            [  # Instance 1.
                [2., 12.],
                [0., 3.],
                [5., 19.],
            ],
        ], dtype=tf.float32)
    # Duplicate along a new leading batch axis so that the keypoint
    # dimension moves to axis 2.
    batched_keypoints = tf.stack(
        [instance_keypoints, instance_keypoints], axis=0)
    return keypoint_ops.keypoints_to_enclosing_bounding_boxes(
        batched_keypoints, keypoints_axis=2)

  output = self.execute(graph_fn, [])
  # Expected [ymin, xmin, ymax, xmax] per instance.
  expected_bboxes = np.array(
      [
          [3., 4., 8., 20.],
          [0., 3., 5., 19.]
      ])
  # Both batch elements were built from the same keypoints, so both must
  # produce the same enclosing boxes.
  for batch_index in range(2):
    self.assertAllClose(expected_bboxes, output[batch_index])
def test_to_normalized_coordinates(self): def test_to_normalized_coordinates(self):
def graph_fn(): def graph_fn():
keypoints = tf.constant([ keypoints = tf.constant([
......
...@@ -329,20 +329,15 @@ def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100, ...@@ -329,20 +329,15 @@ def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
return scores, y_indices, x_indices, channel_indices return scores, y_indices, x_indices, channel_indices
def prediction_tensors_to_boxes(detection_scores, y_indices, x_indices, def prediction_tensors_to_boxes(y_indices, x_indices, height_width_predictions,
channel_indices, height_width_predictions,
offset_predictions): offset_predictions):
"""Converts CenterNet class-center, offset and size predictions to boxes. """Converts CenterNet class-center, offset and size predictions to boxes.
Args: Args:
detection_scores: A [batch, num_boxes] float32 tensor with detection
scores in range [0, 1].
y_indices: A [batch, num_boxes] int32 tensor with y indices corresponding to y_indices: A [batch, num_boxes] int32 tensor with y indices corresponding to
object center locations (expressed in output coordinate frame). object center locations (expressed in output coordinate frame).
x_indices: A [batch, num_boxes] int32 tensor with x indices corresponding to x_indices: A [batch, num_boxes] int32 tensor with x indices corresponding to
object center locations (expressed in output coordinate frame). object center locations (expressed in output coordinate frame).
channel_indices: A [batch, num_boxes] int32 tensor with channel indices
corresponding to object classes.
height_width_predictions: A float tensor of shape [batch_size, height, height_width_predictions: A float tensor of shape [batch_size, height,
width, 2] representing the height and width of a box centered at each width, 2] representing the height and width of a box centered at each
pixel. pixel.
...@@ -353,13 +348,6 @@ def prediction_tensors_to_boxes(detection_scores, y_indices, x_indices, ...@@ -353,13 +348,6 @@ def prediction_tensors_to_boxes(detection_scores, y_indices, x_indices,
Returns: Returns:
detection_boxes: A tensor of shape [batch_size, num_boxes, 4] holding the detection_boxes: A tensor of shape [batch_size, num_boxes, 4] holding the
the raw bounding box coordinates of boxes. the raw bounding box coordinates of boxes.
detection_classes: An integer tensor of shape [batch_size, num_boxes]
indicating the predicted class for each box.
detection_scores: A float tensor of shape [batch_size, num_boxes] indicating
the score for each box.
num_detections: An integer tensor of shape [batch_size,] indicating the
number of boxes detected for each sample in the batch.
""" """
batch_size, num_boxes = _get_shape(y_indices, 2) batch_size, num_boxes = _get_shape(y_indices, 2)
...@@ -383,16 +371,12 @@ def prediction_tensors_to_boxes(detection_scores, y_indices, x_indices, ...@@ -383,16 +371,12 @@ def prediction_tensors_to_boxes(detection_scores, y_indices, x_indices,
heights, widths = tf.unstack(height_width, axis=2) heights, widths = tf.unstack(height_width, axis=2)
y_offsets, x_offsets = tf.unstack(offsets, axis=2) y_offsets, x_offsets = tf.unstack(offsets, axis=2)
detection_classes = channel_indices
num_detections = tf.reduce_sum(tf.to_int32(detection_scores > 0), axis=1)
boxes = tf.stack([y_indices + y_offsets - heights / 2.0, boxes = tf.stack([y_indices + y_offsets - heights / 2.0,
x_indices + x_offsets - widths / 2.0, x_indices + x_offsets - widths / 2.0,
y_indices + y_offsets + heights / 2.0, y_indices + y_offsets + heights / 2.0,
x_indices + x_offsets + widths / 2.0], axis=2) x_indices + x_offsets + widths / 2.0], axis=2)
return boxes, detection_classes, detection_scores, num_detections return boxes
def prediction_tensors_to_temporal_offsets( def prediction_tensors_to_temporal_offsets(
...@@ -753,7 +737,8 @@ def refine_keypoints(regressed_keypoints, ...@@ -753,7 +737,8 @@ def refine_keypoints(regressed_keypoints,
candidate_search_scale=0.3, candidate_search_scale=0.3,
candidate_ranking_mode='min_distance', candidate_ranking_mode='min_distance',
score_distance_offset=1e-6, score_distance_offset=1e-6,
keypoint_depth_candidates=None): keypoint_depth_candidates=None,
keypoint_score_threshold=0.1):
"""Refines regressed keypoints by snapping to the nearest candidate keypoints. """Refines regressed keypoints by snapping to the nearest candidate keypoints.
The initial regressed keypoints represent a full set of keypoints regressed The initial regressed keypoints represent a full set of keypoints regressed
...@@ -817,6 +802,8 @@ def refine_keypoints(regressed_keypoints, ...@@ -817,6 +802,8 @@ def refine_keypoints(regressed_keypoints,
keypoint_depth_candidates: (optional) A float tensor of shape keypoint_depth_candidates: (optional) A float tensor of shape
[batch_size, max_candidates, num_keypoints] indicating the depths for [batch_size, max_candidates, num_keypoints] indicating the depths for
keypoint candidates. keypoint candidates.
keypoint_score_threshold: float, The heatmap score threshold for
a keypoint to become a valid candidate.
Returns: Returns:
A tuple with: A tuple with:
...@@ -903,12 +890,10 @@ def refine_keypoints(regressed_keypoints, ...@@ -903,12 +890,10 @@ def refine_keypoints(regressed_keypoints,
keypoint_depth_candidates)) keypoint_depth_candidates))
if bboxes is None: if bboxes is None:
# Create bboxes from regressed keypoints. # Filter out the chosen candidate with score lower than unmatched
# Shape [batch_size * num_instances, 4]. # keypoint score.
regressed_keypoints_flattened = tf.reshape( mask = tf.cast(nearby_candidate_scores <
regressed_keypoints, [-1, num_keypoints, 2]) keypoint_score_threshold, tf.int32)
bboxes_flattened = keypoint_ops.keypoints_to_enclosing_bounding_boxes(
regressed_keypoints_flattened)
else: else:
bboxes_flattened = tf.reshape(bboxes, [-1, 4]) bboxes_flattened = tf.reshape(bboxes, [-1, 4])
...@@ -937,7 +922,7 @@ def refine_keypoints(regressed_keypoints, ...@@ -937,7 +922,7 @@ def refine_keypoints(regressed_keypoints,
# Filter out the chosen candidate with score lower than unmatched # Filter out the chosen candidate with score lower than unmatched
# keypoint score. # keypoint score.
tf.cast(nearby_candidate_scores < tf.cast(nearby_candidate_scores <
unmatched_keypoint_score, tf.int32) + keypoint_score_threshold, tf.int32) +
tf.cast(min_distances > search_radius, tf.int32)) tf.cast(min_distances > search_radius, tf.int32))
mask = mask > 0 mask = mask > 0
...@@ -3289,24 +3274,30 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3289,24 +3274,30 @@ class CenterNetMetaArch(model.DetectionModel):
k=self._center_params.max_box_predictions)) k=self._center_params.max_box_predictions))
multiclass_scores = tf.gather_nd( multiclass_scores = tf.gather_nd(
object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1) object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1)
boxes_strided, classes, scores, num_detections = (
prediction_tensors_to_boxes(
detection_scores, y_indices, x_indices, channel_indices,
prediction_dict[BOX_SCALE][-1], prediction_dict[BOX_OFFSET][-1]))
boxes = convert_strided_predictions_to_normalized_boxes(
boxes_strided, self._stride, true_image_shapes)
num_detections = tf.reduce_sum(tf.to_int32(detection_scores > 0), axis=1)
postprocess_dict = { postprocess_dict = {
fields.DetectionResultFields.detection_boxes: boxes, fields.DetectionResultFields.detection_scores: detection_scores,
fields.DetectionResultFields.detection_scores: scores,
fields.DetectionResultFields.detection_multiclass_scores: fields.DetectionResultFields.detection_multiclass_scores:
multiclass_scores, multiclass_scores,
fields.DetectionResultFields.detection_classes: classes, fields.DetectionResultFields.detection_classes: channel_indices,
fields.DetectionResultFields.num_detections: num_detections, fields.DetectionResultFields.num_detections: num_detections,
'detection_boxes_strided': boxes_strided
} }
if self._od_params:
boxes_strided = (
prediction_tensors_to_boxes(y_indices, x_indices,
prediction_dict[BOX_SCALE][-1],
prediction_dict[BOX_OFFSET][-1]))
boxes = convert_strided_predictions_to_normalized_boxes(
boxes_strided, self._stride, true_image_shapes)
postprocess_dict.update({
fields.DetectionResultFields.detection_boxes: boxes,
'detection_boxes_strided': boxes_strided
})
if self._kp_params_dict: if self._kp_params_dict:
# If the model is trained to predict only one class of object and its # If the model is trained to predict only one class of object and its
# keypoint, we fall back to a simpler postprocessing function which uses # keypoint, we fall back to a simpler postprocessing function which uses
...@@ -3315,7 +3306,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3315,7 +3306,7 @@ class CenterNetMetaArch(model.DetectionModel):
if len(self._kp_params_dict) == 1 and self._num_classes == 1: if len(self._kp_params_dict) == 1 and self._num_classes == 1:
(keypoints, keypoint_scores, (keypoints, keypoint_scores,
keypoint_depths) = self._postprocess_keypoints_single_class( keypoint_depths) = self._postprocess_keypoints_single_class(
prediction_dict, classes, y_indices, x_indices, boxes_strided, prediction_dict, channel_indices, y_indices, x_indices, None,
num_detections) num_detections)
keypoints, keypoint_scores = ( keypoints, keypoint_scores = (
convert_strided_predictions_to_normalized_keypoints( convert_strided_predictions_to_normalized_keypoints(
...@@ -3334,21 +3325,30 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3334,21 +3325,30 @@ class CenterNetMetaArch(model.DetectionModel):
for kp_dict in self._kp_params_dict.values() for kp_dict in self._kp_params_dict.values()
]) ])
keypoints, keypoint_scores = self._postprocess_keypoints_multi_class( keypoints, keypoint_scores = self._postprocess_keypoints_multi_class(
prediction_dict, classes, y_indices, x_indices, prediction_dict, channel_indices, y_indices, x_indices,
boxes_strided, num_detections) None, num_detections)
keypoints, keypoint_scores = ( keypoints, keypoint_scores = (
convert_strided_predictions_to_normalized_keypoints( convert_strided_predictions_to_normalized_keypoints(
keypoints, keypoint_scores, self._stride, true_image_shapes, keypoints, keypoint_scores, self._stride, true_image_shapes,
clip_out_of_frame_keypoints=clip_keypoints)) clip_out_of_frame_keypoints=clip_keypoints))
# Update instance scores based on keypoints. # Update instance scores based on keypoints.
scores = self._rescore_instances(classes, scores, keypoint_scores) scores = self._rescore_instances(
channel_indices, detection_scores, keypoint_scores)
postprocess_dict.update({ postprocess_dict.update({
fields.DetectionResultFields.detection_scores: scores, fields.DetectionResultFields.detection_scores: scores,
fields.DetectionResultFields.detection_keypoints: keypoints, fields.DetectionResultFields.detection_keypoints: keypoints,
fields.DetectionResultFields.detection_keypoint_scores: fields.DetectionResultFields.detection_keypoint_scores:
keypoint_scores keypoint_scores
}) })
if self._od_params is None:
# Still output the box prediction by enclosing the keypoints for
# evaluation purpose.
boxes = keypoint_ops.keypoints_to_enclosing_bounding_boxes(
keypoints, keypoints_axis=2)
postprocess_dict.update({
fields.DetectionResultFields.detection_boxes: boxes,
})
if self._mask_params: if self._mask_params:
masks = tf.nn.sigmoid(prediction_dict[SEGMENTATION_HEATMAP][-1]) masks = tf.nn.sigmoid(prediction_dict[SEGMENTATION_HEATMAP][-1])
...@@ -3360,7 +3360,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3360,7 +3360,7 @@ class CenterNetMetaArch(model.DetectionModel):
densepose_class_index = self._densepose_params.class_id densepose_class_index = self._densepose_params.class_id
instance_masks, surface_coords = ( instance_masks, surface_coords = (
convert_strided_predictions_to_instance_masks( convert_strided_predictions_to_instance_masks(
boxes, classes, masks, true_image_shapes, boxes, channel_indices, masks, true_image_shapes,
densepose_part_heatmap, densepose_surface_coords, densepose_part_heatmap, densepose_surface_coords,
stride=self._stride, mask_height=self._mask_params.mask_height, stride=self._stride, mask_height=self._mask_params.mask_height,
mask_width=self._mask_params.mask_width, mask_width=self._mask_params.mask_width,
...@@ -3601,7 +3601,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3601,7 +3601,7 @@ class CenterNetMetaArch(model.DetectionModel):
""" """
total_num_keypoints = sum(len(kp_dict.keypoint_indices) for kp_dict total_num_keypoints = sum(len(kp_dict.keypoint_indices) for kp_dict
in self._kp_params_dict.values()) in self._kp_params_dict.values())
batch_size, max_detections, _ = _get_shape(boxes, 3) batch_size, max_detections = _get_shape(classes, 2)
kpt_coords_for_example_list = [] kpt_coords_for_example_list = []
kpt_scores_for_example_list = [] kpt_scores_for_example_list = []
for ex_ind in range(batch_size): for ex_ind in range(batch_size):
...@@ -3626,6 +3626,9 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3626,6 +3626,9 @@ class CenterNetMetaArch(model.DetectionModel):
# Gather the feature map locations corresponding to the object class. # Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1) y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1) x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
if boxes is None:
boxes_for_kpt_class = None
else:
boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1) boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)
# Postprocess keypoints and scores for class and single image. Shapes # Postprocess keypoints and scores for class and single image. Shapes
...@@ -3735,7 +3738,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3735,7 +3738,7 @@ class CenterNetMetaArch(model.DetectionModel):
keypoint_depth_predictions = prediction_dict[get_keypoint_name( keypoint_depth_predictions = prediction_dict[get_keypoint_name(
task_name, KEYPOINT_DEPTH)][-1] task_name, KEYPOINT_DEPTH)][-1]
batch_size, _, _ = _get_shape(boxes, 3) batch_size, _ = _get_shape(classes, 2)
kpt_coords_for_example_list = [] kpt_coords_for_example_list = []
kpt_scores_for_example_list = [] kpt_scores_for_example_list = []
kpt_depths_for_example_list = [] kpt_depths_for_example_list = []
...@@ -3863,6 +3866,9 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3863,6 +3866,9 @@ class CenterNetMetaArch(model.DetectionModel):
...] ...]
y_indices = y_indices[batch_index:batch_index+1, ...] y_indices = y_indices[batch_index:batch_index+1, ...]
x_indices = x_indices[batch_index:batch_index+1, ...] x_indices = x_indices[batch_index:batch_index+1, ...]
if boxes is None:
boxes_slice = None
else:
boxes_slice = boxes[batch_index:batch_index+1, ...] boxes_slice = boxes[batch_index:batch_index+1, ...]
# Gather the regressed keypoints. Final tensor has shape # Gather the regressed keypoints. Final tensor has shape
...@@ -3901,7 +3907,9 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3901,7 +3907,9 @@ class CenterNetMetaArch(model.DetectionModel):
candidate_search_scale=kp_params.candidate_search_scale, candidate_search_scale=kp_params.candidate_search_scale,
candidate_ranking_mode=kp_params.candidate_ranking_mode, candidate_ranking_mode=kp_params.candidate_ranking_mode,
score_distance_offset=kp_params.score_distance_offset, score_distance_offset=kp_params.score_distance_offset,
keypoint_depth_candidates=keypoint_depth_candidates) keypoint_depth_candidates=keypoint_depth_candidates,
keypoint_score_threshold=(
kp_params.keypoint_candidate_score_threshold))
return refined_keypoints, refined_scores, refined_depths return refined_keypoints, refined_scores, refined_depths
......
...@@ -25,6 +25,7 @@ import numpy as np ...@@ -25,6 +25,7 @@ import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from object_detection.builders import post_processing_builder from object_detection.builders import post_processing_builder
from object_detection.core import keypoint_ops
from object_detection.core import losses from object_detection.core import losses
from object_detection.core import preprocessor from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
...@@ -588,18 +589,15 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -588,18 +589,15 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
hw_pred_tensor = tf.constant(hw_pred) hw_pred_tensor = tf.constant(hw_pred)
offset_pred_tensor = tf.constant(offset_pred) offset_pred_tensor = tf.constant(offset_pred)
detection_scores, y_indices, x_indices, channel_indices = ( _, y_indices, x_indices, _ = (
cnma.top_k_feature_map_locations( cnma.top_k_feature_map_locations(
class_pred_tensor, max_pool_kernel_size=3, k=2)) class_pred_tensor, max_pool_kernel_size=3, k=2))
boxes, classes, scores, num_dets = cnma.prediction_tensors_to_boxes( boxes = cnma.prediction_tensors_to_boxes(
detection_scores, y_indices, x_indices, channel_indices, y_indices, x_indices, hw_pred_tensor, offset_pred_tensor)
hw_pred_tensor, offset_pred_tensor) return boxes
return boxes, classes, scores, num_dets
boxes, classes, scores, num_dets = self.execute(graph_fn, [])
np.testing.assert_array_equal(num_dets, [2, 2, 2]) boxes = self.execute(graph_fn, [])
np.testing.assert_allclose( np.testing.assert_allclose(
[[-9, -8, 31, 52], [25, 35, 75, 85]], boxes[0]) [[-9, -8, 31, 52], [25, 35, 75, 85]], boxes[0])
...@@ -608,14 +606,6 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -608,14 +606,6 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_allclose( np.testing.assert_allclose(
[[69.5, 74.5, 90.5, 99.5], [40, 75, 80, 105]], boxes[2]) [[69.5, 74.5, 90.5, 99.5], [40, 75, 80, 105]], boxes[2])
np.testing.assert_array_equal(classes[0], [1, 0])
np.testing.assert_array_equal(classes[1], [2, 1])
np.testing.assert_array_equal(classes[2], [0, 4])
np.testing.assert_allclose(scores[0], [.7, .55])
np.testing.assert_allclose(scores[1][:1], [.9])
np.testing.assert_allclose(scores[2], [1., .8])
def test_offset_prediction(self): def test_offset_prediction(self):
class_pred = np.zeros((3, 128, 128, 5), dtype=np.float32) class_pred = np.zeros((3, 128, 128, 5), dtype=np.float32)
...@@ -1068,12 +1058,19 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -1068,12 +1058,19 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32) keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32)
num_keypoint_candidates = tf.constant(num_keypoints_candidates_np, num_keypoint_candidates = tf.constant(num_keypoints_candidates_np,
dtype=tf.int32) dtype=tf.int32)
# The behavior of bboxes=None is different now. We provide the bboxes
# explicitly by using the regressed keypoints to create the same
# behavior.
regressed_keypoints_flattened = tf.reshape(
regressed_keypoints, [-1, 3, 2])
bboxes_flattened = keypoint_ops.keypoints_to_enclosing_bounding_boxes(
regressed_keypoints_flattened)
(refined_keypoints, refined_scores, _) = cnma.refine_keypoints( (refined_keypoints, refined_scores, _) = cnma.refine_keypoints(
regressed_keypoints, regressed_keypoints,
keypoint_candidates, keypoint_candidates,
keypoint_scores, keypoint_scores,
num_keypoint_candidates, num_keypoint_candidates,
bboxes=None, bboxes=bboxes_flattened,
unmatched_keypoint_score=unmatched_keypoint_score, unmatched_keypoint_score=unmatched_keypoint_score,
box_scale=1.2, box_scale=1.2,
candidate_search_scale=0.3, candidate_search_scale=0.3,
...@@ -1144,6 +1141,85 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -1144,6 +1141,85 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints) np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints)
np.testing.assert_allclose(expected_refined_scores, refined_scores) np.testing.assert_allclose(expected_refined_scores, refined_scores)
def test_refine_keypoints_without_bbox(self):
  """Refinement with bboxes=None keeps candidates outside any box."""
  # Regressed keypoint locations, shape [1, 2 instances, 3 keypoints, 2].
  regressed_keypoints_np = np.array(
      [
          # Example 0.
          [
              [[2.0, 2.0], [6.0, 10.0], [14.0, 7.0]],  # Instance 0.
              [[0.0, 6.0], [3.0, 3.0], [5.0, 7.0]],  # Instance 1.
          ],
      ], dtype=np.float32)
  # Candidate keypoint locations, shape [1, 3 candidates, 3 keypoints, 2].
  keypoint_candidates_np = np.array(
      [
          # Example 0.
          [
              [[2.0, 2.5], [6.0, 10.5], [4.0, 7.0]],  # Candidate 0.
              [[1.0, 8.0], [0.0, 0.0], [2.0, 2.0]],  # Candidate 1.
              [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],  # Candidate 2.
          ],
      ], dtype=np.float32)
  # Heatmap scores for each candidate/keypoint pair.
  keypoint_scores_np = np.array(
      [
          # Example 0.
          [
              [0.8, 0.9, 1.0],  # Candidate 0.
              [0.6, 0.1, 0.9],  # Candidate 1.
              [0.0, 0.0, 0.0],  # Candidate 2.
          ],
      ], dtype=np.float32)
  num_keypoints_candidates_np = np.array(
      [
          # Example 0.
          [2, 2, 2],
      ], dtype=np.int32)
  unmatched_keypoint_score = 0.1

  def graph_fn():
    refined = cnma.refine_keypoints(
        tf.constant(regressed_keypoints_np, dtype=tf.float32),
        tf.constant(keypoint_candidates_np, dtype=tf.float32),
        tf.constant(keypoint_scores_np, dtype=tf.float32),
        tf.constant(num_keypoints_candidates_np, dtype=tf.int32),
        bboxes=None,
        unmatched_keypoint_score=unmatched_keypoint_score,
        box_scale=1.2,
        candidate_search_scale=0.3,
        candidate_ranking_mode='min_distance')
    refined_keypoints, refined_scores, _ = refined
    return refined_keypoints, refined_scores

  refined_keypoints, refined_scores = self.execute(graph_fn, [])

  # With no boxes provided, refinement snaps each regressed keypoint to the
  # nearest candidate without discarding candidates that fall outside any
  # (implicit) bounding box.
  expected_refined_keypoints = np.array(
      [
          # Example 0.
          [
              [[2.0, 2.5], [6.0, 10.5], [4.0, 7.0]],  # Instance 0.
              [[1.0, 8.0], [0.0, 0.0], [4.0, 7.0]],  # Instance 1.
          ],
      ], dtype=np.float32)
  expected_refined_scores = np.array(
      [
          # Example 0.
          [
              [0.8, 0.9, 1.0],  # Instance 0.
              [0.6, 0.1, 1.0],  # Instance 1.
          ],
      ], dtype=np.float32)
  np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints)
  np.testing.assert_allclose(expected_refined_scores, refined_scores)
@parameterized.parameters({'predict_depth': True}, {'predict_depth': False}) @parameterized.parameters({'predict_depth': True}, {'predict_depth': False})
def test_refine_keypoints_with_bboxes(self, predict_depth): def test_refine_keypoints_with_bboxes(self, predict_depth):
regressed_keypoints_np = np.array( regressed_keypoints_np = np.array(
...@@ -1489,7 +1565,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1489,7 +1565,8 @@ def build_center_net_meta_arch(build_resnet=False,
per_keypoint_offset=False, per_keypoint_offset=False,
predict_depth=False, predict_depth=False,
per_keypoint_depth=False, per_keypoint_depth=False,
peak_radius=0): peak_radius=0,
keypoint_only=False):
"""Builds the CenterNet meta architecture.""" """Builds the CenterNet meta architecture."""
if build_resnet: if build_resnet:
feature_extractor = ( feature_extractor = (
...@@ -1522,7 +1599,23 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1522,7 +1599,23 @@ def build_center_net_meta_arch(build_resnet=False,
non_max_suppression_fn, _ = post_processing_builder.build( non_max_suppression_fn, _ = post_processing_builder.build(
post_processing_proto) post_processing_proto)
if detection_only: if keypoint_only:
num_candidates_per_keypoint = 100 if max_box_predictions > 1 else 1
return cnma.CenterNetMetaArch(
is_training=True,
add_summaries=False,
num_classes=num_classes,
feature_extractor=feature_extractor,
image_resizer_fn=image_resizer_fn,
object_center_params=get_fake_center_params(max_box_predictions),
keypoint_params_dict={
_TASK_NAME:
get_fake_kp_params(num_candidates_per_keypoint,
per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius)
},
non_max_suppression_fn=non_max_suppression_fn)
elif detection_only:
return cnma.CenterNetMetaArch( return cnma.CenterNetMetaArch(
is_training=True, is_training=True,
add_summaries=False, add_summaries=False,
...@@ -1825,7 +1918,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1825,7 +1918,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
class_center = np.zeros((1, 32, 32, 10), dtype=np.float32) class_center = np.zeros((1, 32, 32, 10), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32) height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
offset = np.zeros((1, 32, 32, 2), dtype=np.float32) offset = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32) keypoint_heatmaps = np.ones(
(1, 32, 32, num_keypoints), dtype=np.float32) * _logit(0.001)
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32) keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2) keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
...@@ -1971,6 +2065,66 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1971,6 +2065,66 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
detections['detection_surface_coords'][0, 0, :, :], detections['detection_surface_coords'][0, 0, :, :],
np.zeros_like(detections['detection_surface_coords'][0, 0, :, :])) np.zeros_like(detections['detection_surface_coords'][0, 0, :, :]))
def test_postprocess_kpts_no_od(self):
  """Tests postprocessing for a keypoint-only model (no detection task).

  With no object-detection parameters configured, the postprocess output
  still contains detection scores/classes and keypoint results; removed the
  leftover commented-out graph_fn/execute_cpu harness that was dead code.
  """
  target_class_id = 1
  model = build_center_net_meta_arch(keypoint_only=True)
  max_detection = model._center_params.max_box_predictions
  num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)

  class_center = np.zeros((1, 32, 32, 10), dtype=np.float32)
  keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
  keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
  keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)

  # One confident center at (16, 16) for the target class.
  class_probs = np.ones(10) * _logit(0.25)
  class_probs[target_class_id] = _logit(0.75)
  class_center[0, 16, 16] = class_probs
  keypoint_regression[0, 16, 16] = [
      -1., -1.,
      -1., 1.,
      1., -1.,
      1., 1.]
  keypoint_heatmaps[0, 14, 14, 0] = _logit(0.9)
  keypoint_heatmaps[0, 14, 18, 1] = _logit(0.9)
  keypoint_heatmaps[0, 18, 14, 2] = _logit(0.9)
  keypoint_heatmaps[0, 18, 18, 3] = _logit(0.05)  # Note the low score.

  class_center = tf.constant(class_center)
  keypoint_heatmaps = tf.constant(keypoint_heatmaps, dtype=tf.float32)
  keypoint_offsets = tf.constant(keypoint_offsets, dtype=tf.float32)
  keypoint_regression = tf.constant(keypoint_regression, dtype=tf.float32)

  prediction_dict = {
      cnma.OBJECT_CENTER: [class_center],
      cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_HEATMAP):
          [keypoint_heatmaps],
      cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_OFFSET):
          [keypoint_offsets],
      cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_REGRESSION):
          [keypoint_regression],
  }

  detections = model.postprocess(prediction_dict,
                                 tf.constant([[128, 128, 3]]))

  self.assertAllClose(detections['detection_scores'][0],
                      [.75, .5, .5, .5, .5])
  expected_multiclass_scores = [.25] * 10
  expected_multiclass_scores[target_class_id] = .75
  self.assertAllClose(expected_multiclass_scores,
                      detections['detection_multiclass_scores'][0][0])

  self.assertEqual(detections['detection_classes'][0, 0], target_class_id)
  self.assertEqual(detections['num_detections'], [5])
  self.assertAllEqual([1, max_detection, num_keypoints, 2],
                      detections['detection_keypoints'].shape)
  self.assertAllEqual([1, max_detection, num_keypoints],
                      detections['detection_keypoint_scores'].shape)
def test_non_max_suppression(self): def test_non_max_suppression(self):
"""Tests application of NMS on CenterNet detections.""" """Tests application of NMS on CenterNet detections."""
target_class_id = 1 target_class_id = 1
...@@ -2149,7 +2303,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2149,7 +2303,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
class_center = np.zeros((1, 32, 32, 1), dtype=np.float32) class_center = np.zeros((1, 32, 32, 1), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32) height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
offset = np.zeros((1, 32, 32, 2), dtype=np.float32) offset = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32) keypoint_heatmaps = np.ones(
(1, 32, 32, num_keypoints), dtype=np.float32) * _logit(0.001)
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32) keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2) keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.