Commit f15f5995 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Updating how keypoints are rescored in CenterNet. Previously, all keypoint...

Updating how keypoints are rescored in CenterNet. Previously, all keypoint scores were averaged. Now, any keypoints with confidence score 0 (e.g. out of frame keypoints) are not averaged.

PiperOrigin-RevId: 362383830
parent 30f93777
...@@ -43,10 +43,6 @@ NUM_SIZE_CHANNELS = 2 ...@@ -43,10 +43,6 @@ NUM_SIZE_CHANNELS = 2
# Error range for detecting peaks. # Error range for detecting peaks.
PEAK_EPSILON = 1e-6 PEAK_EPSILON = 1e-6
# Constants shared between all keypoint tasks.
UNMATCHED_KEYPOINT_SCORE = 0.1
KEYPOINT_CANDIDATE_SEARCH_SCALE = 0.3
class CenterNetFeatureExtractor(tf.keras.Model): class CenterNetFeatureExtractor(tf.keras.Model):
"""Base class for feature extractors for the CenterNet meta architecture. """Base class for feature extractors for the CenterNet meta architecture.
...@@ -3020,14 +3016,13 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3020,14 +3016,13 @@ class CenterNetMetaArch(model.DetectionModel):
shape_utils.combined_static_and_dynamic_shape(keypoint_scores)) shape_utils.combined_static_and_dynamic_shape(keypoint_scores))
classes_tiled = tf.tile(classes[:, :, tf.newaxis], classes_tiled = tf.tile(classes[:, :, tf.newaxis],
multiples=[1, 1, total_num_keypoints]) multiples=[1, 1, total_num_keypoints])
# TODO(yuhuic): Investigate whether this function will reate subgraphs in # TODO(yuhuic): Investigate whether this function will create subgraphs in
# tflite that will cause the model to run slower at inference. # tflite that will cause the model to run slower at inference.
for kp_params in self._kp_params_dict.values(): for kp_params in self._kp_params_dict.values():
if not kp_params.rescore_instances: if not kp_params.rescore_instances:
continue continue
class_id = kp_params.class_id class_id = kp_params.class_id
keypoint_indices = kp_params.keypoint_indices keypoint_indices = kp_params.keypoint_indices
num_keypoints = len(keypoint_indices)
kpt_mask = tf.reduce_sum( kpt_mask = tf.reduce_sum(
tf.one_hot(keypoint_indices, depth=total_num_keypoints), axis=0) tf.one_hot(keypoint_indices, depth=total_num_keypoints), axis=0)
kpt_mask_tiled = tf.tile(kpt_mask[tf.newaxis, tf.newaxis, :], kpt_mask_tiled = tf.tile(kpt_mask[tf.newaxis, tf.newaxis, :],
...@@ -3037,7 +3032,12 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3037,7 +3032,12 @@ class CenterNetMetaArch(model.DetectionModel):
kpt_mask_tiled == 1.0) kpt_mask_tiled == 1.0)
class_and_keypoint_mask_float = tf.cast(class_and_keypoint_mask, class_and_keypoint_mask_float = tf.cast(class_and_keypoint_mask,
dtype=tf.float32) dtype=tf.float32)
scores_for_class = (1./num_keypoints) * ( visible_keypoints = tf.math.greater(keypoint_scores, 0.0)
num_visible_keypoints = tf.reduce_sum(
class_and_keypoint_mask_float *
tf.cast(visible_keypoints, tf.float32), axis=-1)
num_visible_keypoints = tf.math.maximum(num_visible_keypoints, 1.0)
scores_for_class = (1./num_visible_keypoints) * (
tf.reduce_sum(class_and_keypoint_mask_float * tf.reduce_sum(class_and_keypoint_mask_float *
scores[:, :, tf.newaxis] * scores[:, :, tf.newaxis] *
keypoint_scores, axis=-1)) keypoint_scores, axis=-1))
......
...@@ -1432,6 +1432,7 @@ def get_fake_kp_params(num_candidates_per_keypoint=100, ...@@ -1432,6 +1432,7 @@ def get_fake_kp_params(num_candidates_per_keypoint=100,
keypoint_std_dev=[0.00001] * len(_KEYPOINT_INDICES), keypoint_std_dev=[0.00001] * len(_KEYPOINT_INDICES),
classification_loss=losses.WeightedSigmoidClassificationLoss(), classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(), localization_loss=losses.L1LocalizationLoss(),
unmatched_keypoint_score=0.1,
keypoint_candidate_score_threshold=0.1, keypoint_candidate_score_threshold=0.1,
num_candidates_per_keypoint=num_candidates_per_keypoint, num_candidates_per_keypoint=num_candidates_per_keypoint,
per_keypoint_offset=per_keypoint_offset, per_keypoint_offset=per_keypoint_offset,
...@@ -1818,6 +1819,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1818,6 +1819,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
model = build_center_net_meta_arch() model = build_center_net_meta_arch()
max_detection = model._center_params.max_box_predictions max_detection = model._center_params.max_box_predictions
num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices) num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
unmatched_keypoint_score = (
model._kp_params_dict[_TASK_NAME].unmatched_keypoint_score)
class_center = np.zeros((1, 32, 32, 10), dtype=np.float32) class_center = np.zeros((1, 32, 32, 10), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32) height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
...@@ -1938,7 +1941,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1938,7 +1941,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
expected_kpts_for_obj_0 = np.array( expected_kpts_for_obj_0 = np.array(
[[14., 14.], [14., 18.], [18., 14.], [17., 17.]]) / 32. [[14., 14.], [14., 18.], [18., 14.], [17., 17.]]) / 32.
expected_kpt_scores_for_obj_0 = np.array( expected_kpt_scores_for_obj_0 = np.array(
[0.9, 0.9, 0.9, cnma.UNMATCHED_KEYPOINT_SCORE]) [0.9, 0.9, 0.9, unmatched_keypoint_score])
np.testing.assert_allclose(detections['detection_keypoints'][0][0], np.testing.assert_allclose(detections['detection_keypoints'][0][0],
expected_kpts_for_obj_0, rtol=1e-6) expected_kpts_for_obj_0, rtol=1e-6)
np.testing.assert_allclose(detections['detection_keypoint_scores'][0][0], np.testing.assert_allclose(detections['detection_keypoint_scores'][0][0],
...@@ -2267,7 +2270,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2267,7 +2270,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
scores = tf.constant([[0.5, 0.75]], dtype=tf.float32) scores = tf.constant([[0.5, 0.75]], dtype=tf.float32)
keypoint_scores = tf.constant( keypoint_scores = tf.constant(
[ [
[[0.1, 0.2, 0.3, 0.4, 0.5], [[0.1, 0.0, 0.3, 0.4, 0.5],
[0.1, 0.2, 0.3, 0.4, 0.5]], [0.1, 0.2, 0.3, 0.4, 0.5]],
]) ])
new_scores = model._rescore_instances(classes, scores, keypoint_scores) new_scores = model._rescore_instances(classes, scores, keypoint_scores)
...@@ -2275,7 +2278,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2275,7 +2278,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
new_scores = self.execute_cpu(graph_fn, []) new_scores = self.execute_cpu(graph_fn, [])
expected_scores = np.array( expected_scores = np.array(
[[0.5, 0.75 * (0.1 + 0.2 + 0.3)/3]] [[0.5, 0.75 * (0.1 + 0.3)/2]]
) )
self.assertAllClose(expected_scores, new_scores) self.assertAllClose(expected_scores, new_scores)
......
...@@ -637,6 +637,8 @@ def _maybe_update_config_with_key_value(configs, key, value): ...@@ -637,6 +637,8 @@ def _maybe_update_config_with_key_value(configs, key, value):
_update_keypoint_candidate_score_threshold(configs["model"], value) _update_keypoint_candidate_score_threshold(configs["model"], value)
elif field_name == "rescore_instances": elif field_name == "rescore_instances":
_update_rescore_instances(configs["model"], value) _update_rescore_instances(configs["model"], value)
elif field_name == "unmatched_keypoint_score":
_update_unmatched_keypoint_score(configs["model"], value)
else: else:
return False return False
return True return True
...@@ -1199,3 +1201,16 @@ def _update_rescore_instances(model_config, should_rescore): ...@@ -1199,3 +1201,16 @@ def _update_rescore_instances(model_config, should_rescore):
tf.logging.warning("Ignoring config override key for " tf.logging.warning("Ignoring config override key for "
"rescore_instances since there are multiple keypoint " "rescore_instances since there are multiple keypoint "
"estimation tasks") "estimation tasks")
def _update_unmatched_keypoint_score(model_config, score):
meta_architecture = model_config.WhichOneof("model")
if meta_architecture == "center_net":
if len(model_config.center_net.keypoint_estimation_task) == 1:
kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
kpt_estimation_task.unmatched_keypoint_score = score
else:
tf.logging.warning("Ignoring config override key for "
"unmatched_keypoint_score since there are multiple "
"keypoint estimation tasks")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment