Commit 87769dc6 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Ensure that CenterNetMetaArch can process keypoints even when no instances are present.

PiperOrigin-RevId: 399303089
parent fcb152bf
......@@ -1585,6 +1585,13 @@ def _gather_candidates_at_indices(keypoint_candidates,
gathered_keypoint_candidates = tf.transpose(
nearby_candidate_coords_transposed, [0, 2, 1, 3])
# The reshape operation above may result in a singleton last dimension, but
# downstream code requires it to always be at least 2-valued.
original_shape = tf.shape(gathered_keypoint_candidates)
new_shape = tf.concat((original_shape[:3],
[tf.maximum(original_shape[3], 2)]), 0)
gathered_keypoint_candidates = tf.reshape(gathered_keypoint_candidates,
new_shape)
gathered_keypoint_scores = tf.transpose(nearby_candidate_scores_transposed,
[0, 2, 1])
......
......@@ -1242,6 +1242,44 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints)
np.testing.assert_allclose(expected_refined_scores, refined_scores)
def test_refine_keypoints_with_empty_regressed_keypoints(self):
regressed_keypoints_np = np.zeros((1, 0, 2, 2), dtype=np.float32)
keypoint_candidates_np = np.ones((1, 1, 2, 2), dtype=np.float32)
keypoint_scores_np = np.ones((1, 1, 2), dtype=np.float32)
num_keypoints_candidates_np = np.ones((1, 1), dtype=np.int32)
unmatched_keypoint_score = 0.1
def graph_fn():
regressed_keypoints = tf.constant(
regressed_keypoints_np, dtype=tf.float32)
keypoint_candidates = tf.constant(
keypoint_candidates_np, dtype=tf.float32)
keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32)
num_keypoint_candidates = tf.constant(num_keypoints_candidates_np,
dtype=tf.int32)
# The behavior of bboxes=None is different now. We provide the bboxes
# explicitly by using the regressed keypoints to create the same
# behavior.
regressed_keypoints_flattened = tf.reshape(
regressed_keypoints, [-1, 3, 2])
bboxes_flattened = keypoint_ops.keypoints_to_enclosing_bounding_boxes(
regressed_keypoints_flattened)
(refined_keypoints, refined_scores, _) = cnma.refine_keypoints(
regressed_keypoints,
keypoint_candidates,
keypoint_scores,
num_keypoint_candidates,
bboxes=bboxes_flattened,
unmatched_keypoint_score=unmatched_keypoint_score,
box_scale=1.2,
candidate_search_scale=0.3,
candidate_ranking_mode='min_distance')
return refined_keypoints, refined_scores
refined_keypoints, refined_scores = self.execute(graph_fn, [])
self.assertEqual(refined_keypoints.shape, (1, 0, 2, 2))
self.assertEqual(refined_scores.shape, (1, 0, 2))
def test_refine_keypoints_without_bbox(self):
regressed_keypoints_np = np.array(
[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment