Commit aa94accd authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the keypoint target assigner such that it blacks out the instance bbox

region for the keypoint heatmap if the instance's keypoint visibility is 0.

PiperOrigin-RevId: 378183175
parent acdb71d6
...@@ -1409,8 +1409,10 @@ class CenterNetKeypointTargetAssigner(object): ...@@ -1409,8 +1409,10 @@ class CenterNetKeypointTargetAssigner(object):
[batch_size, num_keypoints] representing number of instances for each [batch_size, num_keypoints] representing number of instances for each
keypoint type. keypoint type.
valid_mask: A float tensor with shape [batch_size, output_height, valid_mask: A float tensor with shape [batch_size, output_height,
output_width] where all values within the regions of the blackout boxes output_width, num_keypoints] where all values within the regions of the
are 0.0 and 1.0 else where. blackout boxes are 0.0 and 1.0 else where. Note that the blackout boxes
are per keypoint type and are blacked out if the keypoint
visibility/weight (of the corresponding keypoint type) is zero.
""" """
out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32) out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32) out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
...@@ -1480,13 +1482,17 @@ class CenterNetKeypointTargetAssigner(object): ...@@ -1480,13 +1482,17 @@ class CenterNetKeypointTargetAssigner(object):
keypoint_std_dev = keypoint_std_dev * tf.stack( keypoint_std_dev = keypoint_std_dev * tf.stack(
[sigma] * num_keypoints, axis=1) [sigma] * num_keypoints, axis=1)
# Generate the valid region mask to ignore regions with target class but # Generate the per-keypoint type valid region mask to ignore regions
# no corresponding keypoints. # with keypoint weights equal to zeros (e.g. visibility is 0).
# Shape: [num_instances]. # shape of valid_mask: [out_height, out_width, num_keypoints]
blackout = tf.logical_and(classes[:, self._class_id] > 0, kp_weight_list = tf.unstack(kp_weights, axis=1)
tf.reduce_max(kp_weights, axis=1) < 1e-3) valid_mask_channel_list = []
valid_mask = ta_utils.blackout_pixel_weights_by_box_regions( for kp_weight in kp_weight_list:
out_height, out_width, boxes.get(), blackout) blackout = kp_weight < 1e-3
valid_mask_channel_list.append(
ta_utils.blackout_pixel_weights_by_box_regions(
out_height, out_width, boxes.get(), blackout))
valid_mask = tf.stack(valid_mask_channel_list, axis=2)
valid_mask_list.append(valid_mask) valid_mask_list.append(valid_mask)
# Apply the Gaussian kernel to the keypoint coordinates. Returned heatmap # Apply the Gaussian kernel to the keypoint coordinates. Returned heatmap
......
...@@ -1699,7 +1699,7 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase): ...@@ -1699,7 +1699,7 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
np.array([[0.0, 0.0, 0.3, 0.3], np.array([[0.0, 0.0, 0.3, 0.3],
[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5],
[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5],
[0.0, 0.0, 1.0, 1.0]]), [0.5, 0.5, 1.0, 1.0]]),
dtype=tf.float32) dtype=tf.float32)
] ]
...@@ -1728,15 +1728,20 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase): ...@@ -1728,15 +1728,20 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
# Verify the number of instances is correct. # Verify the number of instances is correct.
np.testing.assert_array_almost_equal([[0, 1]], np.testing.assert_array_almost_equal([[0, 1]],
num_instances_batch) num_instances_batch)
self.assertAllEqual([1, 30, 20, 2], valid_mask.shape)
# When calling the function, we specify the class id to be 1 (1th and 3rd) # When calling the function, we specify the class id to be 1 (1th and 3rd)
# instance and the keypoint indices to be [0, 2], meaning that the 1st # instance and the keypoint indices to be [0, 2], meaning that the 1st
# instance is the target class with no valid keypoints in it. As a result, # instance is the target class with no valid keypoints in it. As a result,
# the region of the 1st instance boxing box should be blacked out # the region of both keypoint types of the 1st instance boxing box should be
# (0.0, 0.0, 0.5, 0.5), transfering to (0, 0, 15, 10) in absolute output # blacked out (0.0, 0.0, 0.5, 0.5), transfering to (0, 0, 15, 10) in
# space. # absolute output space.
self.assertAlmostEqual(np.sum(valid_mask[:, 0:15, 0:10]), 0.0) self.assertAlmostEqual(np.sum(valid_mask[:, 0:15, 0:10, 0:2]), 0.0)
# All other values are 1.0 so the sum is: 30 * 20 - 15 * 10 = 450. # For the 2nd instance, only the 1st keypoint has visibility of 0 so only
self.assertAlmostEqual(np.sum(valid_mask), 450.0) # the corresponding valid mask contains zeros.
self.assertAlmostEqual(np.sum(valid_mask[:, 15:30, 10:20, 0]), 0.0)
# All other values are 1.0 so the sum is:
# 30 * 20 * 2 - 15 * 10 * 2 - 15 * 10 * 1 = 750.
self.assertAlmostEqual(np.sum(valid_mask), 750.0)
def test_assign_keypoints_offset_targets(self): def test_assign_keypoints_offset_targets(self):
def graph_fn(): def graph_fn():
......
...@@ -2755,8 +2755,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2755,8 +2755,7 @@ class CenterNetMetaArch(model.DetectionModel):
gt_weights_list=gt_weights_list, gt_weights_list=gt_weights_list,
gt_classes_list=gt_classes_list, gt_classes_list=gt_classes_list,
gt_boxes_list=gt_boxes_list) gt_boxes_list=gt_boxes_list)
flattened_valid_mask = _flatten_spatial_dimensions( flattened_valid_mask = _flatten_spatial_dimensions(valid_mask_batch)
tf.expand_dims(valid_mask_batch, axis=-1))
flattened_heapmap_targets = _flatten_spatial_dimensions(keypoint_heatmap) flattened_heapmap_targets = _flatten_spatial_dimensions(keypoint_heatmap)
# Sum over the number of instances per keypoint types to get the total # Sum over the number of instances per keypoint types to get the total
# number of keypoints. Note that this is used to normalized the loss and we # number of keypoints. Note that this is used to normalized the loss and we
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment