Commit f15f5995 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Updating how keypoints are rescored in CenterNet. Previously, all keypoint...

Updating how keypoints are rescored in CenterNet. Previously, all keypoint scores were averaged. Now, any keypoints with confidence score 0 (e.g. out of frame keypoints) are not averaged.

PiperOrigin-RevId: 362383830
parent 30f93777
......@@ -43,10 +43,6 @@ NUM_SIZE_CHANNELS = 2
# Error range for detecting peaks.
PEAK_EPSILON = 1e-6
# Constants shared between all keypoint tasks.
UNMATCHED_KEYPOINT_SCORE = 0.1
KEYPOINT_CANDIDATE_SEARCH_SCALE = 0.3
class CenterNetFeatureExtractor(tf.keras.Model):
"""Base class for feature extractors for the CenterNet meta architecture.
......@@ -3020,14 +3016,13 @@ class CenterNetMetaArch(model.DetectionModel):
shape_utils.combined_static_and_dynamic_shape(keypoint_scores))
classes_tiled = tf.tile(classes[:, :, tf.newaxis],
multiples=[1, 1, total_num_keypoints])
# TODO(yuhuic): Investigate whether this function will reate subgraphs in
# TODO(yuhuic): Investigate whether this function will create subgraphs in
# tflite that will cause the model to run slower at inference.
for kp_params in self._kp_params_dict.values():
if not kp_params.rescore_instances:
continue
class_id = kp_params.class_id
keypoint_indices = kp_params.keypoint_indices
num_keypoints = len(keypoint_indices)
kpt_mask = tf.reduce_sum(
tf.one_hot(keypoint_indices, depth=total_num_keypoints), axis=0)
kpt_mask_tiled = tf.tile(kpt_mask[tf.newaxis, tf.newaxis, :],
......@@ -3037,7 +3032,12 @@ class CenterNetMetaArch(model.DetectionModel):
kpt_mask_tiled == 1.0)
class_and_keypoint_mask_float = tf.cast(class_and_keypoint_mask,
dtype=tf.float32)
scores_for_class = (1./num_keypoints) * (
visible_keypoints = tf.math.greater(keypoint_scores, 0.0)
num_visible_keypoints = tf.reduce_sum(
class_and_keypoint_mask_float *
tf.cast(visible_keypoints, tf.float32), axis=-1)
num_visible_keypoints = tf.math.maximum(num_visible_keypoints, 1.0)
scores_for_class = (1./num_visible_keypoints) * (
tf.reduce_sum(class_and_keypoint_mask_float *
scores[:, :, tf.newaxis] *
keypoint_scores, axis=-1))
......
......@@ -1432,6 +1432,7 @@ def get_fake_kp_params(num_candidates_per_keypoint=100,
keypoint_std_dev=[0.00001] * len(_KEYPOINT_INDICES),
classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(),
unmatched_keypoint_score=0.1,
keypoint_candidate_score_threshold=0.1,
num_candidates_per_keypoint=num_candidates_per_keypoint,
per_keypoint_offset=per_keypoint_offset,
......@@ -1818,6 +1819,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
model = build_center_net_meta_arch()
max_detection = model._center_params.max_box_predictions
num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
unmatched_keypoint_score = (
model._kp_params_dict[_TASK_NAME].unmatched_keypoint_score)
class_center = np.zeros((1, 32, 32, 10), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
......@@ -1938,7 +1941,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
expected_kpts_for_obj_0 = np.array(
[[14., 14.], [14., 18.], [18., 14.], [17., 17.]]) / 32.
expected_kpt_scores_for_obj_0 = np.array(
[0.9, 0.9, 0.9, cnma.UNMATCHED_KEYPOINT_SCORE])
[0.9, 0.9, 0.9, unmatched_keypoint_score])
np.testing.assert_allclose(detections['detection_keypoints'][0][0],
expected_kpts_for_obj_0, rtol=1e-6)
np.testing.assert_allclose(detections['detection_keypoint_scores'][0][0],
......@@ -2267,7 +2270,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
scores = tf.constant([[0.5, 0.75]], dtype=tf.float32)
keypoint_scores = tf.constant(
[
[[0.1, 0.2, 0.3, 0.4, 0.5],
[[0.1, 0.0, 0.3, 0.4, 0.5],
[0.1, 0.2, 0.3, 0.4, 0.5]],
])
new_scores = model._rescore_instances(classes, scores, keypoint_scores)
......@@ -2275,7 +2278,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
new_scores = self.execute_cpu(graph_fn, [])
expected_scores = np.array(
[[0.5, 0.75 * (0.1 + 0.2 + 0.3)/3]]
[[0.5, 0.75 * (0.1 + 0.3)/2]]
)
self.assertAllClose(expected_scores, new_scores)
......
......@@ -637,6 +637,8 @@ def _maybe_update_config_with_key_value(configs, key, value):
_update_keypoint_candidate_score_threshold(configs["model"], value)
elif field_name == "rescore_instances":
_update_rescore_instances(configs["model"], value)
elif field_name == "unmatched_keypoint_score":
_update_unmatched_keypoint_score(configs["model"], value)
else:
return False
return True
......@@ -1199,3 +1201,16 @@ def _update_rescore_instances(model_config, should_rescore):
tf.logging.warning("Ignoring config override key for "
"rescore_instances since there are multiple keypoint "
"estimation tasks")
def _update_unmatched_keypoint_score(model_config, score):
meta_architecture = model_config.WhichOneof("model")
if meta_architecture == "center_net":
if len(model_config.center_net.keypoint_estimation_task) == 1:
kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
kpt_estimation_task.unmatched_keypoint_score = score
else:
tf.logging.warning("Ignoring config override key for "
"unmatched_keypoint_score since there are multiple "
"keypoint estimation tasks")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment