Commit ca0eb4a6, authored by Yu-hui Chen and committed by TF Object Detection Team
Browse files

Refactored the keypoint postprocessing logic to reduce duplicated code and
add a new operating mode for single-instance prediction.

1) Refactored the _postprocess_keypoints_for_class_and_image function so that
it can be reused by both the single-class and multi-class keypoint tasks.
2) Removed the "mod" operator to make the model compatible with WASM.

PiperOrigin-RevId: 345468250
parent e3f8ea22
...@@ -942,9 +942,12 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols, ...@@ -942,9 +942,12 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
indices. indices.
""" """
# Avoid using mod operator to make the ops more easy to be compatible with
# different environments, e.g. WASM.
row_indices = (indices // num_channels) // num_cols row_indices = (indices // num_channels) // num_cols
col_indices = (indices // num_channels) % num_cols col_indices = (indices // num_channels) - row_indices * num_cols
channel_indices = indices % num_channels channel_indices_temp = indices // num_channels
channel_indices = indices - channel_indices_temp * num_channels
return row_indices, col_indices, channel_indices return row_indices, col_indices, channel_indices
...@@ -2925,10 +2928,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2925,10 +2928,7 @@ class CenterNetMetaArch(model.DetectionModel):
# keypoint, we fall back to a simpler postprocessing function which uses # keypoint, we fall back to a simpler postprocessing function which uses
# the ops that are supported by tf.lite on GPU. # the ops that are supported by tf.lite on GPU.
if len(self._kp_params_dict) == 1 and self._num_classes == 1: if len(self._kp_params_dict) == 1 and self._num_classes == 1:
# keypoints, keypoint_scores = self._postprocess_keypoints_simple( keypoints, keypoint_scores = self._postprocess_keypoints_single_class(
# prediction_dict, classes, y_indices, x_indices,
# boxes_strided, num_detections)
keypoints, keypoint_scores = self._postprocess_keypoints_simple(
prediction_dict, classes, y_indices, x_indices, prediction_dict, classes, y_indices, x_indices,
boxes_strided, num_detections) boxes_strided, num_detections)
# The map_fn used to clip out of frame keypoints creates issues when # The map_fn used to clip out of frame keypoints creates issues when
...@@ -2939,7 +2939,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2939,7 +2939,7 @@ class CenterNetMetaArch(model.DetectionModel):
keypoints, keypoint_scores, self._stride, true_image_shapes, keypoints, keypoint_scores, self._stride, true_image_shapes,
clip_out_of_frame_keypoints=False)) clip_out_of_frame_keypoints=False))
else: else:
keypoints, keypoint_scores = self._postprocess_keypoints( keypoints, keypoint_scores = self._postprocess_keypoints_multi_class(
prediction_dict, classes, y_indices, x_indices, prediction_dict, classes, y_indices, x_indices,
boxes_strided, num_detections) boxes_strided, num_detections)
keypoints, keypoint_scores = ( keypoints, keypoint_scores = (
...@@ -3014,10 +3014,18 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3014,10 +3014,18 @@ class CenterNetMetaArch(model.DetectionModel):
return embeddings return embeddings
def _postprocess_keypoints(self, prediction_dict, classes, y_indices, def _postprocess_keypoints_multi_class(self, prediction_dict, classes,
x_indices, boxes, num_detections): y_indices, x_indices, boxes,
num_detections):
"""Performs postprocessing on keypoint predictions. """Performs postprocessing on keypoint predictions.
This is the most general keypoint postprocessing function which supports
multiple keypoint tasks (e.g. human and dog keypoints) and multiple object
detection classes. Note that it is the most expensive postprocessing logics
and is currently not tf.lite/tf.js compatible. See
_postprocess_keypoints_single_class if you plan to export the model in more
portable format.
Args: Args:
prediction_dict: a dictionary holding predicted tensors, returned from the prediction_dict: a dictionary holding predicted tensors, returned from the
predict() method. This dictionary should contain keypoint prediction predict() method. This dictionary should contain keypoint prediction
...@@ -3060,11 +3068,15 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3060,11 +3068,15 @@ class CenterNetMetaArch(model.DetectionModel):
classes, num_detections, ex_ind, kp_params.class_id) classes, num_detections, ex_ind, kp_params.class_id)
num_ind = _get_shape(instance_inds, 1) num_ind = _get_shape(instance_inds, 1)
def true_fn( def true_fn(keypoint_heatmap, keypoint_offsets, keypoint_regression,
keypoint_heatmap, keypoint_offsets, keypoint_regression, classes, y_indices, x_indices, boxes, instance_inds, ex_ind,
classes, y_indices, x_indices, boxes, instance_inds, kp_params):
ex_ind, kp_params):
"""Logics to execute when instance_inds is not an empty set.""" """Logics to execute when instance_inds is not an empty set."""
# Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)
# Postprocess keypoints and scores for class and single image. Shapes # Postprocess keypoints and scores for class and single image. Shapes
# are [1, num_instances_i, num_keypoints_i, 2] and # are [1, num_instances_i, num_keypoints_i, 2] and
# [1, num_instances_i, num_keypoints_i], respectively. Note that # [1, num_instances_i, num_keypoints_i], respectively. Note that
...@@ -3073,15 +3085,17 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3073,15 +3085,17 @@ class CenterNetMetaArch(model.DetectionModel):
kpt_coords_for_class, kpt_scores_for_class = ( kpt_coords_for_class, kpt_scores_for_class = (
self._postprocess_keypoints_for_class_and_image( self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap, keypoint_offsets, keypoint_regression, keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices, x_indices, boxes, instance_inds, classes, y_indices_for_kpt_class, x_indices_for_kpt_class,
ex_ind, kp_params)) boxes_for_kpt_class, ex_ind, kp_params))
# Expand keypoint dimension (with padding) so that coordinates and # Expand keypoint dimension (with padding) so that coordinates and
# scores have shape [1, num_instances_i, num_total_keypoints, 2] and # scores have shape [1, num_instances_i, num_total_keypoints, 2] and
# [1, num_instances_i, num_total_keypoints], respectively. # [1, num_instances_i, num_total_keypoints], respectively.
kpts_coords_for_class_padded, kpt_scores_for_class_padded = ( kpts_coords_for_class_padded, kpt_scores_for_class_padded = (
_pad_to_full_keypoint_dim( _pad_to_full_keypoint_dim(kpt_coords_for_class,
kpt_coords_for_class, kpt_scores_for_class, kpt_scores_for_class,
kp_params.keypoint_indices, total_num_keypoints)) kp_params.keypoint_indices,
total_num_keypoints))
return kpts_coords_for_class_padded, kpt_scores_for_class_padded return kpts_coords_for_class_padded, kpt_scores_for_class_padded
def false_fn(): def false_fn():
...@@ -3135,9 +3149,10 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3135,9 +3149,10 @@ class CenterNetMetaArch(model.DetectionModel):
return keypoints, keypoint_scores return keypoints, keypoint_scores
def _postprocess_keypoints_simple(self, prediction_dict, classes, y_indices, def _postprocess_keypoints_single_class(self, prediction_dict, classes,
x_indices, boxes, num_detections): y_indices, x_indices, boxes,
"""Performs postprocessing on keypoint predictions (one class only). num_detections):
"""Performs postprocessing on keypoint predictions (single class only).
This function handles the special case of keypoint task that the model This function handles the special case of keypoint task that the model
predicts only one class of the bounding box/keypoint (e.g. person). By the predicts only one class of the bounding box/keypoint (e.g. person). By the
...@@ -3186,9 +3201,9 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3186,9 +3201,9 @@ class CenterNetMetaArch(model.DetectionModel):
# are [1, max_detections, num_keypoints, 2] and # are [1, max_detections, num_keypoints, 2] and
# [1, max_detections, num_keypoints], respectively. # [1, max_detections, num_keypoints], respectively.
kpt_coords_for_class, kpt_scores_for_class = ( kpt_coords_for_class, kpt_scores_for_class = (
self._postprocess_keypoints_for_class_and_image_simple( self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap, keypoint_offsets, keypoint_regression, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
classes, y_indices, x_indices, boxes, ex_ind, kp_params)) y_indices, x_indices, boxes, ex_ind, kp_params))
kpt_coords_for_example_list.append(kpt_coords_for_class) kpt_coords_for_example_list.append(kpt_coords_for_class)
kpt_scores_for_example_list.append(kpt_scores_for_class) kpt_scores_for_example_list.append(kpt_scores_for_class)
...@@ -3233,114 +3248,10 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3233,114 +3248,10 @@ class CenterNetMetaArch(model.DetectionModel):
return tf.cast(instance_inds, tf.int32) return tf.cast(instance_inds, tf.int32)
def _postprocess_keypoints_for_class_and_image( def _postprocess_keypoints_for_class_and_image(
self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
y_indices, x_indices, boxes, indices_with_kpt_class, batch_index,
kp_params):
"""Postprocess keypoints for a single image and class.
This function performs the following postprocessing operations on a single
image and single keypoint class:
- Converts keypoints scores to range [0, 1] with sigmoid.
- Determines the detections that correspond to the specified keypoint class.
- Gathers the regressed keypoints at the detection (i.e. box) centers.
- Gathers keypoint candidates from the keypoint heatmaps.
- Snaps regressed keypoints to nearby keypoint candidates.
Args:
keypoint_heatmap: A [batch_size, height, width, num_keypoints] float32
tensor with keypoint heatmaps.
keypoint_offsets: A [batch_size, height, width, 2] float32 tensor with
local offsets to keypoint centers.
keypoint_regression: A [batch_size, height, width, 2 * num_keypoints]
float32 tensor with regressed offsets to all keypoints.
classes: A [batch_size, max_detections] int tensor with class indices for
all detected objects.
y_indices: A [batch_size, max_detections] int tensor with y indices for
all object centers.
x_indices: A [batch_size, max_detections] int tensor with x indices for
all object centers.
boxes: A [batch_size, max_detections, 4] float32 tensor with detected
boxes in the output (strided) frame.
indices_with_kpt_class: A [num_instances] int tensor where each element
indicates the instance location within the `classes` tensor. This is
useful to associate the refined keypoints with the original detections
(i.e. boxes)
batch_index: An integer specifying the index for an example in the batch.
kp_params: A `KeypointEstimationParams` object with parameters for a
single keypoint class.
Returns:
A tuple of
refined_keypoints: A [1, num_instances, num_keypoints, 2] float32 tensor
with refined keypoints for a single class in a single image, expressed
in the output (strided) coordinate frame. Note that `num_instances` is a
dynamic dimension, and corresponds to the number of valid detections
for the specific class.
refined_scores: A [1, num_instances, num_keypoints] float32 tensor with
keypoint scores.
"""
keypoint_indices = kp_params.keypoint_indices
num_keypoints = len(keypoint_indices)
keypoint_heatmap = tf.nn.sigmoid(
keypoint_heatmap[batch_index:batch_index+1, ...])
keypoint_offsets = keypoint_offsets[batch_index:batch_index+1, ...]
keypoint_regression = keypoint_regression[batch_index:batch_index+1, ...]
y_indices = y_indices[batch_index:batch_index+1, ...]
x_indices = x_indices[batch_index:batch_index+1, ...]
boxes_slice = boxes[batch_index:batch_index+1, ...]
# Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, indices_with_kpt_class,
axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, indices_with_kpt_class,
axis=1)
boxes_for_kpt_class = tf.gather(boxes_slice, indices_with_kpt_class, axis=1)
# Gather the regressed keypoints. Final tensor has shape
# [1, num_instances, num_keypoints, 2].
regressed_keypoints_for_objects = regressed_keypoints_at_object_centers(
keypoint_regression, y_indices_for_kpt_class, x_indices_for_kpt_class)
regressed_keypoints_for_objects = tf.reshape(
regressed_keypoints_for_objects, [1, -1, num_keypoints, 2])
# Get the candidate keypoints and scores.
# The shape of keypoint_candidates and keypoint_scores is:
# [1, num_candidates_per_keypoint, num_keypoints, 2] and
# [1, num_candidates_per_keypoint, num_keypoints], respectively.
keypoint_candidates, keypoint_scores, num_keypoint_candidates = (
prediction_tensors_to_keypoint_candidates(
keypoint_heatmap, keypoint_offsets,
keypoint_score_threshold=(
kp_params.keypoint_candidate_score_threshold),
max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
max_candidates=kp_params.num_candidates_per_keypoint))
# Get the refined keypoints and scores, of shape
# [1, num_instances, num_keypoints, 2] and
# [1, num_instances, num_keypoints], respectively.
refined_keypoints, refined_scores = refine_keypoints(
regressed_keypoints=regressed_keypoints_for_objects,
keypoint_candidates=keypoint_candidates,
keypoint_scores=keypoint_scores,
num_keypoint_candidates=num_keypoint_candidates,
bboxes=boxes_for_kpt_class,
unmatched_keypoint_score=kp_params.unmatched_keypoint_score,
box_scale=kp_params.box_scale,
candidate_search_scale=kp_params.candidate_search_scale,
candidate_ranking_mode=kp_params.candidate_ranking_mode)
return refined_keypoints, refined_scores
def _postprocess_keypoints_for_class_and_image_simple(
self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes, self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
y_indices, x_indices, boxes, batch_index, kp_params): y_indices, x_indices, boxes, batch_index, kp_params):
"""Postprocess keypoints for a single image and class. """Postprocess keypoints for a single image and class.
This function is similar to "_postprocess_keypoints_for_class_and_image"
except that it assumes there is only one class of bounding box/keypoint to
be handled. The function is tf.lite compatible.
Args: Args:
keypoint_heatmap: A [batch_size, height, width, num_keypoints] float32 keypoint_heatmap: A [batch_size, height, width, num_keypoints] float32
tensor with keypoint heatmaps. tensor with keypoint heatmaps.
......
...@@ -1207,13 +1207,13 @@ _REID_EMBED_SIZE = 2 ...@@ -1207,13 +1207,13 @@ _REID_EMBED_SIZE = 2
_NUM_FC_LAYERS = 1 _NUM_FC_LAYERS = 1
def get_fake_center_params(): def get_fake_center_params(max_box_predictions=5):
"""Returns the fake object center parameter namedtuple.""" """Returns the fake object center parameter namedtuple."""
return cnma.ObjectCenterParams( return cnma.ObjectCenterParams(
classification_loss=losses.WeightedSigmoidClassificationLoss(), classification_loss=losses.WeightedSigmoidClassificationLoss(),
object_center_loss_weight=1.0, object_center_loss_weight=1.0,
min_box_overlap_iou=1.0, min_box_overlap_iou=1.0,
max_box_predictions=5, max_box_predictions=max_box_predictions,
use_labeled_classes=False) use_labeled_classes=False)
...@@ -1225,7 +1225,7 @@ def get_fake_od_params(): ...@@ -1225,7 +1225,7 @@ def get_fake_od_params():
scale_loss_weight=0.1) scale_loss_weight=0.1)
def get_fake_kp_params(): def get_fake_kp_params(num_candidates_per_keypoint=100):
"""Returns the fake keypoint estimation parameter namedtuple.""" """Returns the fake keypoint estimation parameter namedtuple."""
return cnma.KeypointEstimationParams( return cnma.KeypointEstimationParams(
task_name=_TASK_NAME, task_name=_TASK_NAME,
...@@ -1234,7 +1234,8 @@ def get_fake_kp_params(): ...@@ -1234,7 +1234,8 @@ def get_fake_kp_params():
keypoint_std_dev=[0.00001] * len(_KEYPOINT_INDICES), keypoint_std_dev=[0.00001] * len(_KEYPOINT_INDICES),
classification_loss=losses.WeightedSigmoidClassificationLoss(), classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(), localization_loss=losses.L1LocalizationLoss(),
keypoint_candidate_score_threshold=0.1) keypoint_candidate_score_threshold=0.1,
num_candidates_per_keypoint=num_candidates_per_keypoint)
def get_fake_mask_params(): def get_fake_mask_params():
...@@ -1277,7 +1278,9 @@ def get_fake_temporal_offset_params(): ...@@ -1277,7 +1278,9 @@ def get_fake_temporal_offset_params():
task_loss_weight=1.0) task_loss_weight=1.0)
def build_center_net_meta_arch(build_resnet=False, num_classes=_NUM_CLASSES): def build_center_net_meta_arch(build_resnet=False,
num_classes=_NUM_CLASSES,
max_box_predictions=5):
"""Builds the CenterNet meta architecture.""" """Builds the CenterNet meta architecture."""
if build_resnet: if build_resnet:
feature_extractor = ( feature_extractor = (
...@@ -1297,15 +1300,18 @@ def build_center_net_meta_arch(build_resnet=False, num_classes=_NUM_CLASSES): ...@@ -1297,15 +1300,18 @@ def build_center_net_meta_arch(build_resnet=False, num_classes=_NUM_CLASSES):
pad_to_max_dimesnion=True) pad_to_max_dimesnion=True)
if num_classes == 1: if num_classes == 1:
num_candidates_per_keypoint = 100 if max_box_predictions > 1 else 1
return cnma.CenterNetMetaArch( return cnma.CenterNetMetaArch(
is_training=True, is_training=True,
add_summaries=False, add_summaries=False,
num_classes=num_classes, num_classes=num_classes,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
image_resizer_fn=image_resizer_fn, image_resizer_fn=image_resizer_fn,
object_center_params=get_fake_center_params(), object_center_params=get_fake_center_params(max_box_predictions),
object_detection_params=get_fake_od_params(), object_detection_params=get_fake_od_params(),
keypoint_params_dict={_TASK_NAME: get_fake_kp_params()}) keypoint_params_dict={
_TASK_NAME: get_fake_kp_params(num_candidates_per_keypoint)
})
else: else:
return cnma.CenterNetMetaArch( return cnma.CenterNetMetaArch(
is_training=True, is_training=True,
...@@ -1726,7 +1732,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1726,7 +1732,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
detections['detection_surface_coords'][0, 0, :, :], detections['detection_surface_coords'][0, 0, :, :],
np.zeros_like(detections['detection_surface_coords'][0, 0, :, :])) np.zeros_like(detections['detection_surface_coords'][0, 0, :, :]))
def test_postprocess_simple(self): def test_postprocess_single_class(self):
"""Test the postprocess function.""" """Test the postprocess function."""
model = build_center_net_meta_arch(num_classes=1) model = build_center_net_meta_arch(num_classes=1)
max_detection = model._center_params.max_box_predictions max_detection = model._center_params.max_box_predictions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment