Commit 163ca152 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

New CenterNet multi-pose postprocessing logic that avoids using the topk op,

which runs much slower in the browser.

PiperOrigin-RevId: 394148346
parent 33f0aa0f
...@@ -916,7 +916,9 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict): ...@@ -916,7 +916,9 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
regress_head_kernel_sizes=regress_head_kernel_sizes, regress_head_kernel_sizes=regress_head_kernel_sizes,
score_distance_multiplier=kp_config.score_distance_multiplier, score_distance_multiplier=kp_config.score_distance_multiplier,
std_dev_multiplier=kp_config.std_dev_multiplier, std_dev_multiplier=kp_config.std_dev_multiplier,
rescoring_threshold=kp_config.rescoring_threshold) rescoring_threshold=kp_config.rescoring_threshold,
gaussian_denom_ratio=kp_config.gaussian_denom_ratio,
argmax_postprocessing=kp_config.argmax_postprocessing)
def object_detection_proto_to_params(od_config): def object_detection_proto_to_params(od_config):
...@@ -981,7 +983,8 @@ def object_center_proto_to_params(oc_config): ...@@ -981,7 +983,8 @@ def object_center_proto_to_params(oc_config):
use_labeled_classes=oc_config.use_labeled_classes, use_labeled_classes=oc_config.use_labeled_classes,
keypoint_weights_for_center=keypoint_weights_for_center, keypoint_weights_for_center=keypoint_weights_for_center,
center_head_num_filters=center_head_num_filters, center_head_num_filters=center_head_num_filters,
center_head_kernel_sizes=center_head_kernel_sizes) center_head_kernel_sizes=center_head_kernel_sizes,
peak_max_pool_kernel_size=oc_config.peak_max_pool_kernel_size)
def mask_proto_to_params(mask_config): def mask_proto_to_params(mask_config):
......
...@@ -126,6 +126,8 @@ class ModelBuilderTF2Test( ...@@ -126,6 +126,8 @@ class ModelBuilderTF2Test(
score_distance_multiplier: 11.0 score_distance_multiplier: 11.0
std_dev_multiplier: 2.8 std_dev_multiplier: 2.8
rescoring_threshold: 0.5 rescoring_threshold: 0.5
gaussian_denom_ratio: 0.3
argmax_postprocessing: True
""" """
if customize_head_params: if customize_head_params:
task_proto_txt += """ task_proto_txt += """
...@@ -158,6 +160,7 @@ class ModelBuilderTF2Test( ...@@ -158,6 +160,7 @@ class ModelBuilderTF2Test(
beta: 4.0 beta: 4.0
} }
} }
peak_max_pool_kernel_size: 5
""" """
if customize_head_params: if customize_head_params:
proto_txt += """ proto_txt += """
...@@ -319,6 +322,7 @@ class ModelBuilderTF2Test( ...@@ -319,6 +322,7 @@ class ModelBuilderTF2Test(
else: else:
self.assertEqual(model._center_params.center_head_num_filters, [256]) self.assertEqual(model._center_params.center_head_num_filters, [256])
self.assertEqual(model._center_params.center_head_kernel_sizes, [3]) self.assertEqual(model._center_params.center_head_kernel_sizes, [3])
self.assertEqual(model._center_params.peak_max_pool_kernel_size, 5)
# Check object detection related parameters. # Check object detection related parameters.
self.assertAlmostEqual(model._od_params.offset_loss_weight, 0.1) self.assertAlmostEqual(model._od_params.offset_loss_weight, 0.1)
...@@ -376,6 +380,8 @@ class ModelBuilderTF2Test( ...@@ -376,6 +380,8 @@ class ModelBuilderTF2Test(
self.assertEqual(kp_params.heatmap_head_kernel_sizes, [3]) self.assertEqual(kp_params.heatmap_head_kernel_sizes, [3])
self.assertEqual(kp_params.offset_head_num_filters, [256]) self.assertEqual(kp_params.offset_head_num_filters, [256])
self.assertEqual(kp_params.offset_head_kernel_sizes, [3]) self.assertEqual(kp_params.offset_head_kernel_sizes, [3])
self.assertAlmostEqual(kp_params.gaussian_denom_ratio, 0.3)
self.assertEqual(kp_params.argmax_postprocessing, True)
# Check mask related parameters. # Check mask related parameters.
self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7) self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7)
......
...@@ -782,6 +782,269 @@ def prediction_to_single_instance_keypoints( ...@@ -782,6 +782,269 @@ def prediction_to_single_instance_keypoints(
return keypoint_candidates, keypoint_scores, None return keypoint_candidates, keypoint_scores, None
def _gaussian_weighted_map_const_multi(
    y_grid, x_grid, heatmap, points_y, points_x, boxes,
    gaussian_denom_ratio):
  """Rescores heatmap using the distance information.

  The function is called when the candidate_ranking_mode in the
  KeypointEstimationParams is set to be 'gaussian_weighted_const'. The
  keypoint candidates are ranked using the formula:

    heatmap_score * exp((-distances^2) / (gaussian_denom))

  where 'gaussian_denom' is determined by:

    min(output_feature_height, output_feature_width) * gaussian_denom_ratio

  the 'distances' are the distances between the grid coordinates and the target
  points.

  Note that the postfix 'const' refers to the fact that the denominator is a
  constant given the input image size, not scaled by the size of each of the
  instances.

  Args:
    y_grid: A float tensor with shape [height, width] representing the
      y-coordinate of each pixel grid.
    x_grid: A float tensor with shape [height, width] representing the
      x-coordinate of each pixel grid.
    heatmap: A float tensor with shape [batch_size, height, width,
      num_keypoints] representing the heatmap to be rescored.
    points_y: A float tensor with shape [batch_size, num_instances,
      num_keypoints] representing the y coordinates of the target points for
      each channel.
    points_x: A float tensor with shape [batch_size, num_instances,
      num_keypoints] representing the x coordinates of the target points for
      each channel.
    boxes: A tensor of shape [batch_size, num_instances, 4] with predicted
      bounding boxes for each instance, expressed in the output coordinate
      frame.
    gaussian_denom_ratio: A constant used in the above formula that determines
      the denominator of the Gaussian kernel.

  Returns:
    A float tensor with shape [batch_size, height, width, num_instances,
    num_keypoints] representing the rescored heatmap: one rescored map per
    instance, zeroed outside that instance's box.
  """
  batch_size, num_instances, _ = _get_shape(boxes, 3)
  _, height, width, num_keypoints = _get_shape(heatmap, 4)
  # [batch_size, height, width, num_instances, num_keypoints].
  # Note that we intentionally avoid using tf.newaxis as TfLite converter
  # doesn't like it.
  y_diff = (
      tf.reshape(y_grid, [1, height, width, 1, 1]) -
      tf.reshape(points_y, [batch_size, 1, 1, num_instances, num_keypoints]))
  x_diff = (
      tf.reshape(x_grid, [1, height, width, 1, 1]) -
      tf.reshape(points_x, [batch_size, 1, 1, num_instances, num_keypoints]))
  # Squared Euclidean distance from every grid cell to every target point.
  distance_square = y_diff**2 + x_diff**2

  y_min, x_min, y_max, x_max = tf.split(boxes, 4, axis=2)

  # Make the mask with all 1.0 in the box regions.
  # Shape: [batch_size, height, width, num_instances]
  in_boxes = tf.math.logical_and(
      tf.math.logical_and(
          tf.reshape(y_grid, [1, height, width, 1]) >= tf.reshape(
              y_min, [batch_size, 1, 1, num_instances]),
          tf.reshape(y_grid, [1, height, width, 1]) < tf.reshape(
              y_max, [batch_size, 1, 1, num_instances])),
      tf.math.logical_and(
          tf.reshape(x_grid, [1, height, width, 1]) >= tf.reshape(
              x_min, [batch_size, 1, 1, num_instances]),
          tf.reshape(x_grid, [1, height, width, 1]) < tf.reshape(
              x_max, [batch_size, 1, 1, num_instances])))
  in_boxes = tf.cast(in_boxes, dtype=tf.float32)

  # Denominator is shared by every instance (hence the '_const' postfix):
  # it depends only on the output feature map size.
  gaussian_denom = tf.cast(
      tf.minimum(height, width), dtype=tf.float32) * gaussian_denom_ratio

  # shape: [batch_size, height, width, num_instances, num_keypoints]
  gaussian_map = tf.exp((-1 * distance_square) / gaussian_denom)
  # Broadcast the per-keypoint heatmap over the num_instances axis, apply the
  # Gaussian weighting, and zero out locations outside each instance's box.
  return tf.expand_dims(
      heatmap, axis=3) * gaussian_map * tf.reshape(
          in_boxes, [batch_size, height, width, num_instances, 1])
def prediction_tensors_to_multi_instance_kpts(
    keypoint_heatmap_predictions,
    keypoint_heatmap_offsets,
    keypoint_score_heatmap=None):
  """Converts keypoint heatmap predictions and offsets to keypoint candidates.

  This function is similar to the 'prediction_tensors_to_single_instance_kpts'
  function except that the input keypoint_heatmap_predictions is prepared to
  have an additional 'num_instances' dimension for multi-instance prediction.

  Args:
    keypoint_heatmap_predictions: A float tensor of shape [batch_size, height,
      width, num_instances, num_keypoints] representing the per-keypoint and
      per-instance heatmaps which is used for finding the best keypoint
      candidate locations.
    keypoint_heatmap_offsets: A float tensor of shape [batch_size, height,
      width, 2 * num_keypoints] representing the per-keypoint offsets.
    keypoint_score_heatmap: (optional) A float tensor of shape [batch_size,
      height, width, num_keypoints] representing the heatmap
      which is used for reporting the confidence scores. If not provided, then
      the values in the keypoint_heatmap_predictions will be used.

  Returns:
    keypoint_candidates: A tensor of shape
      [batch_size, num_instances, num_keypoints, 2] holding the
      location of keypoint candidates in [y, x] format (expressed in absolute
      coordinates in the output coordinate frame).
    keypoint_scores: A float tensor of shape
      [batch_size, num_instances, num_keypoints] with the scores for each
      keypoint candidate. The scores come directly from the heatmap predictions.
  """
  batch_size, height, width, num_instances, num_keypoints = _get_shape(
      keypoint_heatmap_predictions, 5)

  # [batch_size, height * width, num_instances * num_keypoints].
  feature_map_flattened = tf.reshape(
      keypoint_heatmap_predictions,
      [batch_size, -1, num_instances * num_keypoints])

  # [batch_size, num_instances * num_keypoints].
  # argmax over the flattened spatial axis replaces top_k: one peak per
  # (instance, keypoint) channel, which runs much faster in the browser.
  peak_flat_indices = tf.math.argmax(
      feature_map_flattened, axis=1, output_type=tf.dtypes.int32)

  # Get x and y indices corresponding to the top indices in the flat array.
  y_indices, x_indices = (
      row_col_indices_from_flattened_indices(peak_flat_indices, width))
  # [batch_size * num_instances * num_keypoints].
  y_indices = tf.reshape(y_indices, [-1])
  x_indices = tf.reshape(x_indices, [-1])

  # Prepare the indices to gather the offsets from the keypoint_heatmap_offsets.
  batch_idx = _multi_range(
      limit=batch_size, value_repetitions=num_keypoints * num_instances)
  kpts_idx = _multi_range(
      limit=num_keypoints, value_repetitions=1,
      range_repetitions=batch_size * num_instances)
  combined_indices = tf.stack([
      batch_idx,
      y_indices,
      x_indices,
      kpts_idx
  ], axis=1)

  keypoint_heatmap_offsets = tf.reshape(
      keypoint_heatmap_offsets, [batch_size, height, width, num_keypoints, 2])
  # Retrieve the keypoint offsets: shape:
  # [batch_size * num_instance * num_keypoints, 2].
  selected_offsets_flat = tf.gather_nd(keypoint_heatmap_offsets,
                                       combined_indices)
  y_offsets, x_offsets = tf.unstack(selected_offsets_flat, axis=1)

  # Peak grid coordinates refined by the sub-pixel offsets.
  keypoint_candidates = tf.stack([
      tf.cast(y_indices, dtype=tf.float32) + tf.expand_dims(y_offsets, axis=0),
      tf.cast(x_indices, dtype=tf.float32) + tf.expand_dims(x_offsets, axis=0)
  ], axis=2)
  keypoint_candidates = tf.reshape(
      keypoint_candidates, [batch_size, num_instances, num_keypoints, 2])

  if keypoint_score_heatmap is None:
    # No dedicated score heatmap: reduce the per-instance prediction maps over
    # the instance axis to a [batch_size, height, width, num_keypoints] map
    # and gather scores at the peak locations.
    keypoint_scores = tf.gather_nd(
        tf.reduce_max(keypoint_heatmap_predictions, axis=3), combined_indices)
  else:
    keypoint_scores = tf.gather_nd(keypoint_score_heatmap, combined_indices)
  return keypoint_candidates, tf.reshape(
      keypoint_scores, [batch_size, num_instances, num_keypoints])
def prediction_to_keypoints_argmax(
    prediction_dict,
    object_y_indices,
    object_x_indices,
    boxes,
    task_name,
    kp_params):
  """Postprocess function to predict multi instance keypoints with argmax op.

  This is a different implementation of the original keypoint postprocessing
  function such that it avoids using topk op (replaced by argmax) as it runs
  much slower in the browser.

  Args:
    prediction_dict: a dictionary holding predicted tensors, returned from the
      predict() method. This dictionary should contain keypoint prediction
      feature maps for each keypoint task.
    object_y_indices: A float tensor of shape [batch_size, max_instances]
      representing the location indices of the object centers.
    object_x_indices: A float tensor of shape [batch_size, max_instances]
      representing the location indices of the object centers.
      NOTE(review): both index tensors are stacked with integer ranges below
      (combined_obj_indices) and only cast to float afterwards, so they are
      presumably integer-typed despite being documented as float — verify
      against the caller.
    boxes: A tensor of shape [batch_size, num_instances, 4] with predicted
      bounding boxes for each instance, expressed in the output coordinate
      frame.
    task_name: string, the name of the task this namedtuple corresponds to.
      Note that it should be an unique identifier of the task.
    kp_params: A `KeypointEstimationParams` object with parameters for a single
      keypoint class.

  Returns:
    A tuple of two tensors:
      keypoint_candidates: A float tensor with shape [batch_size,
        num_instances, num_keypoints, 2] representing the yx-coordinates of
        the keypoints in the output feature map space.
      keypoint_scores: A float tensor with shape [batch_size, num_instances,
        num_keypoints] representing the keypoint prediction scores.

  Raises:
    ValueError: if the candidate_ranking_mode is not supported.
  """
  keypoint_heatmap = tf.nn.sigmoid(prediction_dict[
      get_keypoint_name(task_name, KEYPOINT_HEATMAP)][-1])
  keypoint_offset = prediction_dict[
      get_keypoint_name(task_name, KEYPOINT_OFFSET)][-1]
  keypoint_regression = prediction_dict[
      get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
  batch_size, height, width, num_keypoints = _get_shape(keypoint_heatmap, 4)

  # Create the y,x grids: [height, width]
  (y_grid, x_grid) = ta_utils.image_shape_to_grids(height, width)

  # Prepare the indices to retrieve the information from object centers.
  num_instances = _get_shape(object_y_indices, 2)[1]
  combined_obj_indices = tf.stack([
      _multi_range(batch_size, value_repetitions=num_instances),
      tf.reshape(object_y_indices, [-1]),
      tf.reshape(object_x_indices, [-1])
  ], axis=1)

  # Select the regression vectors from the object center.
  selected_regression_flat = tf.gather_nd(
      keypoint_regression, combined_obj_indices)
  selected_regression = tf.reshape(
      selected_regression_flat, [batch_size, num_instances, num_keypoints, 2])
  (y_reg, x_reg) = tf.unstack(selected_regression, axis=3)

  # shape: [batch_size, num_instances, num_keypoints].
  # Regressed keypoint location = object center + regressed displacement;
  # these serve as the target points for the Gaussian rescoring below.
  y_regressed = tf.cast(
      tf.reshape(object_y_indices, [batch_size, num_instances, 1]),
      dtype=tf.float32) + y_reg
  x_regressed = tf.cast(
      tf.reshape(object_x_indices, [batch_size, num_instances, 1]),
      dtype=tf.float32) + x_reg

  if kp_params.candidate_ranking_mode == 'gaussian_weighted_const':
    # [batch_size, height, width, num_instances, num_keypoints]: heatmap
    # weighted by distance to each instance's regressed keypoints.
    rescored_heatmap = _gaussian_weighted_map_const_multi(
        y_grid, x_grid, keypoint_heatmap, y_regressed, x_regressed, boxes,
        kp_params.gaussian_denom_ratio)
    # shape: [batch_size, height, width, num_keypoints].
    keypoint_score_heatmap = tf.math.reduce_max(rescored_heatmap, axis=3)
  else:
    raise ValueError(
        'Unsupported ranking mode in the multipose no topk method: %s' %
        kp_params.candidate_ranking_mode)
  (keypoint_candidates,
   keypoint_scores) = prediction_tensors_to_multi_instance_kpts(
       keypoint_heatmap_predictions=rescored_heatmap,
       keypoint_heatmap_offsets=keypoint_offset,
       keypoint_score_heatmap=keypoint_score_heatmap)
  return keypoint_candidates, keypoint_scores
def regressed_keypoints_at_object_centers(regressed_keypoint_predictions, def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
y_indices, x_indices): y_indices, x_indices):
"""Returns the regressed keypoints at specified object centers. """Returns the regressed keypoints at specified object centers.
...@@ -1900,7 +2163,8 @@ class KeypointEstimationParams( ...@@ -1900,7 +2163,8 @@ class KeypointEstimationParams(
'heatmap_head_kernel_sizes', 'offset_head_num_filters', 'heatmap_head_kernel_sizes', 'offset_head_num_filters',
'offset_head_kernel_sizes', 'regress_head_num_filters', 'offset_head_kernel_sizes', 'regress_head_num_filters',
'regress_head_kernel_sizes', 'score_distance_multiplier', 'regress_head_kernel_sizes', 'score_distance_multiplier',
'std_dev_multiplier', 'rescoring_threshold' 'std_dev_multiplier', 'rescoring_threshold', 'gaussian_denom_ratio',
'argmax_postprocessing'
])): ])):
"""Namedtuple to host object detection related parameters. """Namedtuple to host object detection related parameters.
...@@ -1948,7 +2212,9 @@ class KeypointEstimationParams( ...@@ -1948,7 +2212,9 @@ class KeypointEstimationParams(
regress_head_kernel_sizes=(3), regress_head_kernel_sizes=(3),
score_distance_multiplier=0.1, score_distance_multiplier=0.1,
std_dev_multiplier=1.0, std_dev_multiplier=1.0,
rescoring_threshold=0.0): rescoring_threshold=0.0,
argmax_postprocessing=False,
gaussian_denom_ratio=0.1):
"""Constructor with default values for KeypointEstimationParams. """Constructor with default values for KeypointEstimationParams.
Args: Args:
...@@ -2049,6 +2315,12 @@ class KeypointEstimationParams( ...@@ -2049,6 +2315,12 @@ class KeypointEstimationParams(
True. The detection score of an instance is set to be the average over True. The detection score of an instance is set to be the average over
the scores of the keypoints which their scores higher than the the scores of the keypoints which their scores higher than the
threshold. threshold.
argmax_postprocessing: Whether to use the keypoint postprocessing logic
that replaces the topk op with argmax. Usually used when exporting the
model for predicting keypoints of multiple instances in the browser.
gaussian_denom_ratio: The ratio used to multiply the image size to
determine the denominator of the Gaussian formula. Only applicable when
the candidate_ranking_mode is set to be 'gaussian_weighted_const'.
Returns: Returns:
An initialized KeypointEstimationParams namedtuple. An initialized KeypointEstimationParams namedtuple.
...@@ -2067,7 +2339,8 @@ class KeypointEstimationParams( ...@@ -2067,7 +2339,8 @@ class KeypointEstimationParams(
heatmap_head_num_filters, heatmap_head_kernel_sizes, heatmap_head_num_filters, heatmap_head_kernel_sizes,
offset_head_num_filters, offset_head_kernel_sizes, offset_head_num_filters, offset_head_kernel_sizes,
regress_head_num_filters, regress_head_kernel_sizes, regress_head_num_filters, regress_head_kernel_sizes,
score_distance_multiplier, std_dev_multiplier, rescoring_threshold) score_distance_multiplier, std_dev_multiplier, rescoring_threshold,
argmax_postprocessing, gaussian_denom_ratio)
class ObjectCenterParams( class ObjectCenterParams(
...@@ -2075,7 +2348,7 @@ class ObjectCenterParams( ...@@ -2075,7 +2348,7 @@ class ObjectCenterParams(
'classification_loss', 'object_center_loss_weight', 'heatmap_bias_init', 'classification_loss', 'object_center_loss_weight', 'heatmap_bias_init',
'min_box_overlap_iou', 'max_box_predictions', 'use_labeled_classes', 'min_box_overlap_iou', 'max_box_predictions', 'use_labeled_classes',
'keypoint_weights_for_center', 'center_head_num_filters', 'keypoint_weights_for_center', 'center_head_num_filters',
'center_head_kernel_sizes' 'center_head_kernel_sizes', 'peak_max_pool_kernel_size'
])): ])):
"""Namedtuple to store object center prediction related parameters.""" """Namedtuple to store object center prediction related parameters."""
...@@ -2090,7 +2363,8 @@ class ObjectCenterParams( ...@@ -2090,7 +2363,8 @@ class ObjectCenterParams(
use_labeled_classes=False, use_labeled_classes=False,
keypoint_weights_for_center=None, keypoint_weights_for_center=None,
center_head_num_filters=(256), center_head_num_filters=(256),
center_head_kernel_sizes=(3)): center_head_kernel_sizes=(3),
peak_max_pool_kernel_size=3):
"""Constructor with default values for ObjectCenterParams. """Constructor with default values for ObjectCenterParams.
Args: Args:
...@@ -2115,6 +2389,8 @@ class ObjectCenterParams( ...@@ -2115,6 +2389,8 @@ class ObjectCenterParams(
by the object center prediction head. by the object center prediction head.
center_head_kernel_sizes: kernel size of the convolutional layers used center_head_kernel_sizes: kernel size of the convolutional layers used
by the object center prediction head. by the object center prediction head.
peak_max_pool_kernel_size: Max pool kernel size to use to pull off peak
score locations in a neighborhood for the object detection heatmap.
Returns: Returns:
An initialized ObjectCenterParams namedtuple. An initialized ObjectCenterParams namedtuple.
""" """
...@@ -2123,7 +2399,8 @@ class ObjectCenterParams( ...@@ -2123,7 +2399,8 @@ class ObjectCenterParams(
object_center_loss_weight, heatmap_bias_init, object_center_loss_weight, heatmap_bias_init,
min_box_overlap_iou, max_box_predictions, min_box_overlap_iou, max_box_predictions,
use_labeled_classes, keypoint_weights_for_center, use_labeled_classes, keypoint_weights_for_center,
center_head_num_filters, center_head_kernel_sizes) center_head_num_filters, center_head_kernel_sizes,
peak_max_pool_kernel_size)
class MaskParams( class MaskParams(
...@@ -3773,7 +4050,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3773,7 +4050,8 @@ class CenterNetMetaArch(model.DetectionModel):
# center predictions. # center predictions.
detection_scores, y_indices, x_indices, channel_indices = ( detection_scores, y_indices, x_indices, channel_indices = (
top_k_feature_map_locations( top_k_feature_map_locations(
object_center_prob, max_pool_kernel_size=3, object_center_prob,
max_pool_kernel_size=self._center_params.peak_max_pool_kernel_size,
k=self._center_params.max_box_predictions)) k=self._center_params.max_box_predictions))
multiclass_scores = tf.gather_nd( multiclass_scores = tf.gather_nd(
object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1) object_center_prob, tf.stack([y_indices, x_indices], -1), batch_dims=1)
...@@ -3808,6 +4086,18 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3808,6 +4086,18 @@ class CenterNetMetaArch(model.DetectionModel):
# the ops that are supported by tf.lite on GPU. # the ops that are supported by tf.lite on GPU.
clip_keypoints = self._should_clip_keypoints() clip_keypoints = self._should_clip_keypoints()
if len(self._kp_params_dict) == 1 and self._num_classes == 1: if len(self._kp_params_dict) == 1 and self._num_classes == 1:
task_name, kp_params = next(iter(self._kp_params_dict.items()))
keypoint_depths = None
if kp_params.argmax_postprocessing:
keypoints, keypoint_scores = (
prediction_to_keypoints_argmax(
prediction_dict,
object_y_indices=y_indices,
object_x_indices=x_indices,
boxes=boxes_strided,
task_name=task_name,
kp_params=kp_params))
else:
(keypoints, keypoint_scores, (keypoints, keypoint_scores,
keypoint_depths) = self._postprocess_keypoints_single_class( keypoint_depths) = self._postprocess_keypoints_single_class(
prediction_dict, channel_indices, y_indices, x_indices, prediction_dict, channel_indices, y_indices, x_indices,
......
...@@ -807,6 +807,77 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -807,6 +807,77 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_allclose(expected_keypoint_candidates, keypoint_cands) np.testing.assert_allclose(expected_keypoint_candidates, keypoint_cands)
np.testing.assert_allclose(expected_keypoint_scores, keypoint_scores) np.testing.assert_allclose(expected_keypoint_scores, keypoint_scores)
@parameterized.parameters({'provide_keypoint_score': True},
                          {'provide_keypoint_score': False})
def test_prediction_to_multi_instance_keypoints(self, provide_keypoint_score):
  """Checks multi-instance keypoint decoding from heatmaps and offsets.

  Builds a [1, 9, 9, 3, 4] per-instance heatmap with peaks at the four
  corners for instances 0 (score 0.9) and 1 (score 0.8), plus per-keypoint
  sub-pixel offsets, then verifies the decoded [y, x] candidates and scores.
  """
  height, width = 9, 9
  num_instances, num_keypoints = 3, 4
  heatmaps_np = np.zeros((1, height, width, num_instances, num_keypoints),
                         dtype=np.float32)
  # Peak (y, x) per keypoint, ordered TL, TR, BL, BR; instance 1 peaks lie
  # one pixel further toward the image border than instance 0's.
  inst0_peaks = [(1, 1), (1, 7), (7, 1), (7, 7)]
  inst1_peaks = [(2, 2), (2, 8), (8, 2), (8, 8)]
  for kpt, (py, px) in enumerate(inst0_peaks):
    heatmaps_np[0, py, px, 0, kpt] = 0.9
  for kpt, (py, px) in enumerate(inst1_peaks):
    heatmaps_np[0, py, px, 1, kpt] = 0.8

  offsets_np = np.zeros((1, height, width, 2 * num_keypoints),
                        dtype=np.float32)
  # Sub-pixel (dy, dx) offset stored at each peak, pointing inward.
  inst0_deltas = [(0.5, 0.5), (0.5, -0.5), (-0.5, 0.5), (-0.5, -0.5)]
  inst1_deltas = [(0.3, 0.3), (0.3, -0.3), (-0.3, 0.3), (-0.3, -0.3)]
  for peaks, deltas in ((inst0_peaks, inst0_deltas),
                        (inst1_peaks, inst1_deltas)):
    for kpt, ((py, px), (dy, dx)) in enumerate(zip(peaks, deltas)):
      offsets_np[0, py, px, 2 * kpt] = dy
      offsets_np[0, py, px, 2 * kpt + 1] = dx

  def graph_fn():
    heatmaps = tf.constant(heatmaps_np, dtype=tf.float32)
    offsets = tf.constant(offsets_np, dtype=tf.float32)
    # Passing None explicitly matches the parameter's default value.
    score_heatmap = (
        tf.reduce_max(heatmaps, axis=3) if provide_keypoint_score else None)
    return cnma.prediction_tensors_to_multi_instance_kpts(
        heatmaps, offsets, score_heatmap)

  candidates, scores = self.execute(graph_fn, [])

  # Expected candidate = peak coordinate + stored sub-pixel offset.
  np.testing.assert_allclose(
      [[1.5, 1.5], [1.5, 6.5], [6.5, 1.5], [6.5, 6.5]],
      candidates[0, 0, :, :])
  np.testing.assert_allclose(
      [[2.3, 2.3], [2.3, 7.7], [7.7, 2.3], [7.7, 7.7]],
      candidates[0, 1, :, :])
  np.testing.assert_allclose([0.9, 0.9, 0.9, 0.9], scores[0, 0, :])
  np.testing.assert_allclose([0.8, 0.8, 0.8, 0.8], scores[0, 1, :])
def test_keypoint_candidate_prediction_per_keypoints(self): def test_keypoint_candidate_prediction_per_keypoints(self):
keypoint_heatmap_np = np.zeros((2, 3, 3, 2), dtype=np.float32) keypoint_heatmap_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_heatmap_np[0, 0, 0, 0] = 1.0 keypoint_heatmap_np[0, 0, 0, 0] = 1.0
...@@ -1644,7 +1715,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100, ...@@ -1644,7 +1715,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100,
predict_depth=False, predict_depth=False,
per_keypoint_depth=False, per_keypoint_depth=False,
peak_radius=0, peak_radius=0,
candidate_ranking_mode='min_distance'): candidate_ranking_mode='min_distance',
argmax_postprocessing=False):
"""Returns the fake keypoint estimation parameter namedtuple.""" """Returns the fake keypoint estimation parameter namedtuple."""
return cnma.KeypointEstimationParams( return cnma.KeypointEstimationParams(
task_name=_TASK_NAME, task_name=_TASK_NAME,
...@@ -1660,7 +1732,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100, ...@@ -1660,7 +1732,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100,
predict_depth=predict_depth, predict_depth=predict_depth,
per_keypoint_depth=per_keypoint_depth, per_keypoint_depth=per_keypoint_depth,
offset_peak_radius=peak_radius, offset_peak_radius=peak_radius,
candidate_ranking_mode=candidate_ranking_mode) candidate_ranking_mode=candidate_ranking_mode,
argmax_postprocessing=argmax_postprocessing)
def get_fake_mask_params(): def get_fake_mask_params():
...@@ -1715,7 +1788,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1715,7 +1788,8 @@ def build_center_net_meta_arch(build_resnet=False,
per_keypoint_depth=False, per_keypoint_depth=False,
peak_radius=0, peak_radius=0,
keypoint_only=False, keypoint_only=False,
candidate_ranking_mode='min_distance'): candidate_ranking_mode='min_distance',
argmax_postprocessing=False):
"""Builds the CenterNet meta architecture.""" """Builds the CenterNet meta architecture."""
if build_resnet: if build_resnet:
feature_extractor = ( feature_extractor = (
...@@ -1762,7 +1836,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1762,7 +1836,8 @@ def build_center_net_meta_arch(build_resnet=False,
get_fake_kp_params(num_candidates_per_keypoint, get_fake_kp_params(num_candidates_per_keypoint,
per_keypoint_offset, predict_depth, per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius, per_keypoint_depth, peak_radius,
candidate_ranking_mode) candidate_ranking_mode,
argmax_postprocessing)
}, },
non_max_suppression_fn=non_max_suppression_fn) non_max_suppression_fn=non_max_suppression_fn)
elif detection_only: elif detection_only:
...@@ -1790,7 +1865,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1790,7 +1865,8 @@ def build_center_net_meta_arch(build_resnet=False,
get_fake_kp_params(num_candidates_per_keypoint, get_fake_kp_params(num_candidates_per_keypoint,
per_keypoint_offset, predict_depth, per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius, per_keypoint_depth, peak_radius,
candidate_ranking_mode) candidate_ranking_mode,
argmax_postprocessing)
}, },
non_max_suppression_fn=non_max_suppression_fn) non_max_suppression_fn=non_max_suppression_fn)
else: else:
...@@ -2324,17 +2400,32 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2324,17 +2400,32 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertAllClose(expected_multiclass_scores, self.assertAllClose(expected_multiclass_scores,
detections['detection_multiclass_scores'][0][0]) detections['detection_multiclass_scores'][0][0])
def test_postprocess_single_class(self): @parameterized.parameters(
{
'candidate_ranking_mode': 'min_distance',
'argmax_postprocessing': False
},
{
'candidate_ranking_mode': 'gaussian_weighted_const',
'argmax_postprocessing': True
})
def test_postprocess_single_class(self, candidate_ranking_mode,
argmax_postprocessing):
"""Test the postprocess function.""" """Test the postprocess function."""
model = build_center_net_meta_arch(num_classes=1) model = build_center_net_meta_arch(
num_classes=1, max_box_predictions=5, per_keypoint_offset=True,
candidate_ranking_mode=candidate_ranking_mode,
argmax_postprocessing=argmax_postprocessing)
max_detection = model._center_params.max_box_predictions max_detection = model._center_params.max_box_predictions
num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices) num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
class_center = np.zeros((1, 32, 32, 1), dtype=np.float32) class_center = np.zeros((1, 32, 32, 1), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32) height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
offset = np.zeros((1, 32, 32, 2), dtype=np.float32) offset = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32) keypoint_heatmaps = np.ones(
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32) (1, 32, 32, num_keypoints), dtype=np.float32) * _logit(0.01)
keypoint_offsets = np.zeros(
(1, 32, 32, num_keypoints * 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2) keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
class_probs = np.zeros(1) class_probs = np.zeros(1)
...@@ -2387,6 +2478,9 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2387,6 +2478,9 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertEqual(detections['num_detections'], [5]) self.assertEqual(detections['num_detections'], [5])
self.assertAllEqual([1, max_detection, num_keypoints, 2], self.assertAllEqual([1, max_detection, num_keypoints, 2],
detections['detection_keypoints'].shape) detections['detection_keypoints'].shape)
self.assertAllClose(
[[0.4375, 0.4375], [0.4375, 0.5625], [0.5625, 0.4375]],
detections['detection_keypoints'][0, 0, 0:3, :])
self.assertAllEqual([1, max_detection, num_keypoints], self.assertAllEqual([1, max_detection, num_keypoints],
detections['detection_keypoint_scores'].shape) detections['detection_keypoint_scores'].shape)
......
...@@ -111,6 +111,10 @@ message CenterNet { ...@@ -111,6 +111,10 @@ message CenterNet {
// Parameters to determine the architecture of the object center prediction // Parameters to determine the architecture of the object center prediction
// head. // head.
optional PredictionHeadParams center_head_params = 8; optional PredictionHeadParams center_head_params = 8;
// Max pool kernel size to use to pull off peak score locations in a
// neighborhood for the object detection heatmap.
optional int32 peak_max_pool_kernel_size = 9 [default = 3];
} }
optional ObjectCenterParams object_center_params = 5; optional ObjectCenterParams object_center_params = 5;
...@@ -266,6 +270,16 @@ message CenterNet { ...@@ -266,6 +270,16 @@ message CenterNet {
// with scores higher than the threshold. // with scores higher than the threshold.
optional float rescoring_threshold = 30 [default = 0.0]; optional float rescoring_threshold = 30 [default = 0.0];
// The ratio used to multiply the output feature map size to determine the
// denominator in the Gaussian formula. Only applicable when the
// candidate_ranking_mode is set to be 'gaussian_weighted_const'.
optional float gaussian_denom_ratio = 31 [default = 0.1];
// Whether to use the keypoint postprocessing logic that replaces topk op
// with argmax. Usually used when exporting the model for predicting
// keypoints of multiple instances in the browser.
optional bool argmax_postprocessing = 32 [default = false];
// Parameters to determine the architecture of the keypoint heatmap // Parameters to determine the architecture of the keypoint heatmap
// prediction head. // prediction head.
optional PredictionHeadParams heatmap_head_params = 25; optional PredictionHeadParams heatmap_head_params = 25;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment