"git@developer.sourcefind.cn:change/sglang.git" did not exist on "ecc9f3e47abd8fa1a23020a91b4a50088fd3c060"
Commit dc2d15db authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Added the specialized postprocessing function for single instance keypoint

estimation. It is designed specifically to run the model in the browser
environment using tf.js.

PiperOrigin-RevId: 352609409
parent 7c9d9ede
...@@ -32,6 +32,7 @@ from object_detection.core import model ...@@ -32,6 +32,7 @@ from object_detection.core import model
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner as cn_assigner from object_detection.core import target_assigner as cn_assigner
from object_detection.utils import shape_utils from object_detection.utils import shape_utils
from object_detection.utils import target_assigner_utils as ta_utils
# Number of channels needed to predict size and offsets. # Number of channels needed to predict size and offsets.
NUM_OFFSET_CHANNELS = 2 NUM_OFFSET_CHANNELS = 2
...@@ -526,6 +527,125 @@ def prediction_tensors_to_keypoint_candidates( ...@@ -526,6 +527,125 @@ def prediction_tensors_to_keypoint_candidates(
return keypoint_candidates, keypoint_scores, num_candidates return keypoint_candidates, keypoint_scores, num_candidates
def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
                                            keypoint_offset,
                                            keypoint_regression, stride,
                                            object_center_std_dev,
                                            keypoint_std_dev, kp_params):
  """Postprocess function to predict single instance keypoints.

  This is a simplified postprocessing function based on the assumption that
  there is only one instance in the image. If there are multiple instances in
  the image, the model prefers to predict the one that is closest to the image
  center. Here is a high-level description of what this function does:
  1) Object heatmap re-weighted by image center Gaussian is used to determine
     the instance center.
  2) Regressed keypoint locations are retrieved from the instance center. The
     Gaussian kernel is applied to the regressed keypoint locations to
     re-weight the keypoint heatmap. This is to select the keypoints that are
     associated with the center instance without using top_k op.
  3) The keypoint locations are computed by the re-weighted keypoint heatmap
     and the keypoint offset.

  Args:
    object_heatmap: A float tensor of shape [1, height, width, 1] representing
      the heatmap of the class.
    keypoint_heatmap: A float tensor of shape [1, height, width, num_keypoints]
      representing the per-keypoint heatmaps.
    keypoint_offset: A float tensor of shape [1, height, width, 2] (or [1,
      height, width, 2 * num_keypoints] if 'per_keypoint_offset' is set True)
      representing the per-keypoint offsets.
    keypoint_regression: A float tensor of shape [1, height, width, 2 *
      num_keypoints] representing the joint regression prediction.
    stride: The stride in the output space.
    object_center_std_dev: The standard deviation of the Gaussian mask which is
      applied to the object_heatmap. The goal is to upweight the instance that
      is closer to the image center. Expressed in units of input image pixels.
    keypoint_std_dev: The standard deviation of the Gaussian masks which are
      applied to the keypoint_heatmap based on the regressed joint location. It
      is used to upweight the keypoint joints that belong to the targeted
      instance. If keypoint_std_dev contains 1 element, all keypoint joints
      will share the same value. Otherwise, it must contain num_keypoints
      elements, representing the standard deviation corresponding to each
      joint. Expressed in units of input image pixels.
    kp_params: A `KeypointEstimationParams` object with parameters for a single
      keypoint class.

  Returns:
    A tuple of two tensors:
      keypoint_candidates: A float tensor with shape [1, 1, num_keypoints, 2]
        representing the yx-coordinates of the keypoints in the output feature
        map space.
      keypoint_scores: A float tensor with shape [1, 1, num_keypoints]
        representing the keypoint prediction scores.

  Raises:
    ValueError: if the input keypoint_std_dev doesn't have valid number of
      elements (1 or num_keypoints).
  """
  # The number of keypoints is taken from the task configuration, not from the
  # tensor shapes, so it is a static Python int usable in tf.range/reshape.
  num_keypoints = len(kp_params.keypoint_std_dev)
  batch_size, height, width, _ = _get_shape(keypoint_heatmap, 4)

  # Build a Gaussian mask centered at the image center and multiply it into
  # the object heatmap so the instance nearest the center scores highest.
  # Standard deviations are divided by `stride` to convert from input-image
  # pixels to output feature-map units.
  image_center_y = tf.convert_to_tensor([0.5 * height], dtype=tf.float32)
  image_center_x = tf.convert_to_tensor([0.5 * width], dtype=tf.float32)
  (y_grid, x_grid) = ta_utils.image_shape_to_grids(height, width)
  # Mask shape: [1, height, width, 1]. The one-hot of a single channel routes
  # the single Gaussian into the heatmap's only channel.
  object_mask = tf.expand_dims(
      ta_utils.coordinates_to_heatmap(y_grid, x_grid, image_center_y,
                                      image_center_x,
                                      object_center_std_dev / stride,
                                      tf.one_hot(tf.range(1), depth=1)), axis=0)
  object_heatmap = tf.math.multiply(object_heatmap, object_mask)

  # Pick the highest score and location of the weighted object heatmap.
  # k=1 keeps this cheap; per docstring this path avoids large top_k ops.
  _, y_indices, x_indices, _ = (
      top_k_feature_map_locations(
          object_heatmap, max_pool_kernel_size=1, k=1, per_channel=True))
  # combined_indices: [batch_size * num_indices, 3] rows of (batch, y, x) for
  # gather_nd below.
  _, num_indices = _get_shape(y_indices, 2)
  combined_indices = tf.stack([
      _multi_range(batch_size, value_repetitions=num_indices),
      tf.reshape(y_indices, [-1]),
      tf.reshape(x_indices, [-1])
  ], axis=1)

  # Select the regression vectors from the object center.
  selected_regression_flat = tf.gather_nd(keypoint_regression, combined_indices)
  # shape: [num_keypoints, 2]
  regression_offsets = tf.reshape(selected_regression_flat, [num_keypoints, -1])
  (y_reg, x_reg) = tf.unstack(regression_offsets, axis=1)
  # Regressed joint locations = instance center + per-joint regression offset.
  y_regressed = tf.cast(y_indices, dtype=tf.float32) + y_reg
  x_regressed = tf.cast(x_indices, dtype=tf.float32) + x_reg

  # Prepare and apply the keypoint heatmap masks: one Gaussian per joint,
  # centered at that joint's regressed location (std devs rescaled to output
  # space). This suppresses peaks belonging to other instances.
  keypoint_std_dev = [x / stride for x in keypoint_std_dev]
  if len(keypoint_std_dev) == 1:
    std_dev = tf.convert_to_tensor(
        keypoint_std_dev * num_keypoints, dtype=tf.float32)
  elif len(keypoint_std_dev) == num_keypoints:
    std_dev = tf.convert_to_tensor(
        keypoint_std_dev, dtype=tf.float32)
  else:
    raise ValueError('keypoint_std_dev needs to have length either '
                     'equal to 1 or num_keypoints.')
  # channel_onehot assigns the i-th Gaussian to the i-th keypoint channel.
  channel_onehot = tf.one_hot(tf.range(num_keypoints), depth=num_keypoints)
  keypoint_mask = tf.expand_dims(
      ta_utils.coordinates_to_heatmap(y_grid, x_grid, y_regressed, x_regressed,
                                      std_dev, channel_onehot), axis=0)
  keypoint_predictions = tf.math.multiply(keypoint_heatmap, keypoint_mask)

  # Get the keypoint locations/scores:
  # keypoint_candidates: [1, 1, num_keypoints, 2]
  # keypoint_scores: [1, 1, num_keypoints]
  # max_candidates=1 reflects the single-instance assumption.
  (keypoint_candidates, keypoint_scores,
   _) = prediction_tensors_to_keypoint_candidates(
       keypoint_predictions,
       keypoint_offset,
       keypoint_score_threshold=kp_params.keypoint_candidate_score_threshold,
       max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
       max_candidates=1)
  return keypoint_candidates, keypoint_scores
def regressed_keypoints_at_object_centers(regressed_keypoint_predictions, def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
y_indices, x_indices): y_indices, x_indices):
"""Returns the regressed keypoints at specified object centers. """Returns the regressed keypoints at specified object centers.
...@@ -2990,6 +3110,89 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2990,6 +3110,89 @@ class CenterNetMetaArch(model.DetectionModel):
return postprocess_dict return postprocess_dict
def postprocess_single_instance_keypoints(self, prediction_dict,
true_image_shapes,
object_center_std_dev,
keypoint_std_dev):
"""Postprocess for predicting single instance keypoints.
This postprocess function is a special case of predicting the keypoint of
a single instance in the image (original CenterNet postprocess supports
multi-instance prediction). Due to the simplification assumption, this
postprocessing function achieves much faster inference time.
Here is a short list of the modifications made in this function:
1) Assume the model predicts only single class keypoint.
2) Assume there is only one instance in the image. If multiple instances
appear in the image, the model tends to predict the one that is closer
to the image center (the other ones are considered as background and
are rejected by the model).
3) Avoid using top_k ops in the postprocessing logics since it is slower
than using argmax.
4) The predictions other than the keypoints are ignored, e.g. boxes.
5) The input batch size is assumed to be 1.
Args:
prediction_dict: a dictionary holding predicted tensors from "predict"
function.
true_image_shapes: int32 tensor of shape [batch, 3] where each row is of
the form [height, width, channels] indicating the shapes of true images
in the resized images, as resized images can be padded with zeros.
object_center_std_dev: The standard deviation of the Gaussian mask which
is applied to the object_heatmap. The goal is to upweight the instance
that is closer to the image center. Expressed in units of input image
pixels.
keypoint_std_dev: The standard deviation of the Gaussian masks which are
applied to the keypoint_heatmap based on the regressed joint location.
It is used to upweight the keypoint joints that belongs to the targeted
instance. If keypoint_std_dev contains one value, then we assume the
same value is applied to all keypoint joints. If keypoint_std_dev is a
list, it must contain num_keypoints elements, representing the standard
deviation corresponding to each joints.
Returns:
detections: a dictionary containing the following fields
detection_keypoints: A float tensor of shape
[1, 1, num_keypoints, 2] with normalized keypoints. Any invalid
keypoints have their coordinates and scores set to 0.0.
detection_keypoint_scores: A float tensor of shape
[1, 1, num_keypoints] with scores for each keypoint.
"""
# The number of keypoint task is expected to be 1.
assert len(self._kp_params_dict) == 1
task_name, kp_params = next(iter(self._kp_params_dict.items()))
keypoint_heatmap = tf.nn.sigmoid(prediction_dict[get_keypoint_name(
task_name, KEYPOINT_HEATMAP)][-1])
keypoint_offset = prediction_dict[get_keypoint_name(task_name,
KEYPOINT_OFFSET)][-1]
keypoint_regression = prediction_dict[get_keypoint_name(
task_name, KEYPOINT_REGRESSION)][-1]
object_heatmap = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
keypoints, keypoint_scores = (
prediction_to_single_instance_keypoints(
object_heatmap=object_heatmap,
keypoint_heatmap=keypoint_heatmap,
keypoint_offset=keypoint_offset,
keypoint_regression=keypoint_regression,
stride=self._stride,
object_center_std_dev=object_center_std_dev,
keypoint_std_dev=keypoint_std_dev,
kp_params=kp_params))
keypoints, keypoint_scores = (
convert_strided_predictions_to_normalized_keypoints(
keypoints,
keypoint_scores,
self._stride,
true_image_shapes,
clip_out_of_frame_keypoints=False))
postprocess_dict = {
fields.DetectionResultFields.detection_keypoints: keypoints,
fields.DetectionResultFields.detection_keypoint_scores: keypoint_scores
}
return postprocess_dict
def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices): def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices):
"""Performs postprocessing on embedding predictions. """Performs postprocessing on embedding predictions.
......
...@@ -734,6 +734,75 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -734,6 +734,75 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_array_equal(expected_num_keypoint_candidates, np.testing.assert_array_equal(expected_num_keypoint_candidates,
num_keypoint_candidates) num_keypoint_candidates)
def test_prediction_to_single_instance_keypoints(self):
image_size = (9, 9)
object_heatmap_np = np.zeros((1, image_size[0], image_size[1], 1),
dtype=np.float32)
# This should be picked.
object_heatmap_np[0, 4, 4, 0] = 0.9
# This shouldn't be picked since it's farther away from the center.
object_heatmap_np[0, 2, 2, 0] = 1.0
keypoint_heatmap_np = np.zeros((1, image_size[0], image_size[1], 4),
dtype=np.float32)
# Top-left corner should be picked.
keypoint_heatmap_np[0, 1, 1, 0] = 0.9
keypoint_heatmap_np[0, 4, 4, 0] = 1.0
# Top-right corner should be picked.
keypoint_heatmap_np[0, 1, 7, 1] = 0.9
keypoint_heatmap_np[0, 4, 4, 1] = 1.0
# Bottom-left corner should be picked.
keypoint_heatmap_np[0, 7, 1, 2] = 0.9
keypoint_heatmap_np[0, 4, 4, 2] = 1.0
# Bottom-right corner should be picked.
keypoint_heatmap_np[0, 7, 7, 3] = 0.9
keypoint_heatmap_np[0, 4, 4, 3] = 1.0
keypoint_offset_np = np.zeros((1, image_size[0], image_size[1], 2),
dtype=np.float32)
keypoint_offset_np[0, 1, 1] = [0.5, 0.5]
keypoint_offset_np[0, 1, 7] = [0.5, -0.5]
keypoint_offset_np[0, 7, 1] = [-0.5, 0.5]
keypoint_offset_np[0, 7, 7] = [-0.5, -0.5]
keypoint_regression_np = np.zeros((1, image_size[0], image_size[1], 8),
dtype=np.float32)
keypoint_regression_np[0, 4, 4] = [-3, -3, -3, 3, 3, -3, 3, 3]
kp_params = get_fake_kp_params(num_candidates_per_keypoint=1)
def graph_fn():
object_heatmap = tf.constant(object_heatmap_np, dtype=tf.float32)
keypoint_heatmap = tf.constant(keypoint_heatmap_np, dtype=tf.float32)
keypoint_offset = tf.constant(keypoint_offset_np, dtype=tf.float32)
keypoint_regression = tf.constant(
keypoint_regression_np, dtype=tf.float32)
(keypoint_cands, keypoint_scores) = (
cnma.prediction_to_single_instance_keypoints(
object_heatmap,
keypoint_heatmap,
keypoint_offset,
keypoint_regression,
stride=4,
object_center_std_dev=image_size[0] / 2,
keypoint_std_dev=[image_size[0] / 10],
kp_params=kp_params))
return keypoint_cands, keypoint_scores
(keypoint_cands, keypoint_scores) = self.execute(graph_fn, [])
expected_keypoint_candidates = [[[
[1.5, 1.5], # top-left
[1.5, 6.5], # top-right
[6.5, 1.5], # bottom-left
[6.5, 6.5], # bottom-right
]]]
expected_keypoint_scores = [[[0.9, 0.9, 0.9, 0.9]]]
np.testing.assert_allclose(expected_keypoint_candidates, keypoint_cands)
np.testing.assert_allclose(expected_keypoint_scores, keypoint_scores)
def test_keypoint_candidate_prediction_per_keypoints(self): def test_keypoint_candidate_prediction_per_keypoints(self):
keypoint_heatmap_np = np.zeros((2, 3, 3, 2), dtype=np.float32) keypoint_heatmap_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_heatmap_np[0, 0, 0, 0] = 1.0 keypoint_heatmap_np[0, 0, 0, 0] = 1.0
...@@ -1798,6 +1867,59 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1798,6 +1867,59 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertAllEqual([1, max_detection, num_keypoints], self.assertAllEqual([1, max_detection, num_keypoints],
detections['detection_keypoint_scores'].shape) detections['detection_keypoint_scores'].shape)
def test_postprocess_single_instance(self):
"""Test the postprocess single instance function."""
model = build_center_net_meta_arch(num_classes=1)
num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
class_center = np.zeros((1, 32, 32, 1), dtype=np.float32)
keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
class_probs = np.zeros(1)
class_probs[0] = _logit(0.75)
class_center[0, 16, 16] = class_probs
keypoint_regression[0, 16, 16] = [
-1., -1.,
-1., 1.,
1., -1.,
1., 1.]
keypoint_heatmaps[0, 14, 14, 0] = _logit(0.9)
keypoint_heatmaps[0, 14, 18, 1] = _logit(0.9)
keypoint_heatmaps[0, 18, 14, 2] = _logit(0.9)
keypoint_heatmaps[0, 18, 18, 3] = _logit(0.05) # Note the low score.
class_center = tf.constant(class_center)
keypoint_heatmaps = tf.constant(keypoint_heatmaps, dtype=tf.float32)
keypoint_offsets = tf.constant(keypoint_offsets, dtype=tf.float32)
keypoint_regression = tf.constant(keypoint_regression, dtype=tf.float32)
prediction_dict = {
cnma.OBJECT_CENTER: [class_center],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_HEATMAP):
[keypoint_heatmaps],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_OFFSET):
[keypoint_offsets],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_REGRESSION):
[keypoint_regression],
}
def graph_fn():
detections = model.postprocess_single_instance_keypoints(
prediction_dict,
tf.constant([[128, 128, 3]]),
object_center_std_dev=32,
keypoint_std_dev=[32])
return detections
detections = self.execute_cpu(graph_fn, [])
self.assertAllEqual([1, 1, num_keypoints, 2],
detections['detection_keypoints'].shape)
self.assertAllEqual([1, 1, num_keypoints],
detections['detection_keypoint_scores'].shape)
def test_get_instance_indices(self): def test_get_instance_indices(self):
classes = tf.constant([[0, 1, 2, 0], [2, 1, 2, 2]], dtype=tf.int32) classes = tf.constant([[0, 1, 2, 0], [2, 1, 2, 2]], dtype=tf.int32)
num_detections = tf.constant([1, 3], dtype=tf.int32) num_detections = tf.constant([1, 3], dtype=tf.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment