Commit 94618de6 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Surgical update to CenterNet keypoint postprocessing to use bounding boxes if...

Surgical update to CenterNet keypoint postprocessing to use bounding boxes if present. This is shown to reduce keypoint snapping across individuals.

PiperOrigin-RevId: 369432018
parent 3c46cfb5
...@@ -3484,6 +3484,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3484,6 +3484,7 @@ class CenterNetMetaArch(model.DetectionModel):
fields.DetectionResultFields.num_detections: num_detections, fields.DetectionResultFields.num_detections: num_detections,
} }
boxes_strided = None
if self._od_params: if self._od_params:
boxes_strided = ( boxes_strided = (
prediction_tensors_to_boxes(y_indices, x_indices, prediction_tensors_to_boxes(y_indices, x_indices,
...@@ -3506,8 +3507,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3506,8 +3507,8 @@ class CenterNetMetaArch(model.DetectionModel):
if len(self._kp_params_dict) == 1 and self._num_classes == 1: if len(self._kp_params_dict) == 1 and self._num_classes == 1:
(keypoints, keypoint_scores, (keypoints, keypoint_scores,
keypoint_depths) = self._postprocess_keypoints_single_class( keypoint_depths) = self._postprocess_keypoints_single_class(
prediction_dict, channel_indices, y_indices, x_indices, None, prediction_dict, channel_indices, y_indices, x_indices,
num_detections) boxes_strided, num_detections)
keypoints, keypoint_scores = ( keypoints, keypoint_scores = (
convert_strided_predictions_to_normalized_keypoints( convert_strided_predictions_to_normalized_keypoints(
keypoints, keypoint_scores, self._stride, true_image_shapes, keypoints, keypoint_scores, self._stride, true_image_shapes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment