Commit a905826c authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Cleaned up the target_assigner and CenterNet meta arch so that the depth

logic doesn't depend on the per_keypoint_offset value, and switched to
get_batch_predictions_from_indices to be consistent with the other prediction
heads.

PiperOrigin-RevId: 369925004
parent d54edbcf
...@@ -1318,7 +1318,8 @@ class CenterNetKeypointTargetAssigner(object): ...@@ -1318,7 +1318,8 @@ class CenterNetKeypointTargetAssigner(object):
keypoint_std_dev=None, keypoint_std_dev=None,
per_keypoint_offset=False, per_keypoint_offset=False,
peak_radius=0, peak_radius=0,
compute_heatmap_sparse=False): compute_heatmap_sparse=False,
per_keypoint_depth=False):
"""Initializes a CenterNet keypoints target assigner. """Initializes a CenterNet keypoints target assigner.
Args: Args:
...@@ -1349,12 +1350,16 @@ class CenterNetKeypointTargetAssigner(object): ...@@ -1349,12 +1350,16 @@ class CenterNetKeypointTargetAssigner(object):
version of the Op that computes the heatmap. The sparse version scales version of the Op that computes the heatmap. The sparse version scales
better with number of keypoint types, but in some cases is known to better with number of keypoint types, but in some cases is known to
cause an OOM error. See (b/170989061). cause an OOM error. See (b/170989061).
per_keypoint_depth: A bool indicating whether the model predicts the depth
of each keypoint in independent channels. Similar to
per_keypoint_offset but for the keypoint depth.
""" """
self._stride = stride self._stride = stride
self._class_id = class_id self._class_id = class_id
self._keypoint_indices = keypoint_indices self._keypoint_indices = keypoint_indices
self._per_keypoint_offset = per_keypoint_offset self._per_keypoint_offset = per_keypoint_offset
self._per_keypoint_depth = per_keypoint_depth
self._peak_radius = peak_radius self._peak_radius = peak_radius
self._compute_heatmap_sparse = compute_heatmap_sparse self._compute_heatmap_sparse = compute_heatmap_sparse
if keypoint_std_dev is None: if keypoint_std_dev is None:
...@@ -1686,14 +1691,15 @@ class CenterNetKeypointTargetAssigner(object): ...@@ -1686,14 +1691,15 @@ class CenterNetKeypointTargetAssigner(object):
Returns: Returns:
batch_indices: an integer tensor of shape [num_total_instances, 3] (or batch_indices: an integer tensor of shape [num_total_instances, 3] (or
[num_total_instances, 4] if 'per_keypoint_offset' is set True) holding [num_total_instances, 4] if 'per_keypoint_depth' is set True) holding
the indices inside the predicted tensor which should be penalized. The the indices inside the predicted tensor which should be penalized. The
first column indicates the index along the batch dimension and the first column indicates the index along the batch dimension and the
second and third columns indicate the index along the y and x second and third columns indicate the index along the y and x
dimensions respectively. The fourth column corresponds to the channel dimensions respectively. The fourth column corresponds to the channel
dimension (if 'per_keypoint_offset' is set True). dimension (if 'per_keypoint_offset' is set True).
batch_depths: a float tensor of shape [num_total_instances, 1] indicating batch_depths: a float tensor of shape [num_total_instances, 1] (or
the target depth of each keypoint. [num_total_instances, num_keypoints] if per_keypoint_depth is set True)
indicating the target depth of each keypoint.
batch_weights: a float tensor of shape [num_total_instances] indicating batch_weights: a float tensor of shape [num_total_instances] indicating
the weight of each prediction. the weight of each prediction.
Note that num_total_instances = batch_size * num_instances * Note that num_total_instances = batch_size * num_instances *
...@@ -1800,7 +1806,7 @@ class CenterNetKeypointTargetAssigner(object): ...@@ -1800,7 +1806,7 @@ class CenterNetKeypointTargetAssigner(object):
# Prepare the batch indices to be prepended. # Prepare the batch indices to be prepended.
batch_index = tf.fill( batch_index = tf.fill(
[num_instances * num_keypoints * num_neighbors, 1], i) [num_instances * num_keypoints * num_neighbors, 1], i)
if self._per_keypoint_offset: if self._per_keypoint_depth:
tiled_keypoint_types = self._get_keypoint_types( tiled_keypoint_types = self._get_keypoint_types(
num_instances, num_keypoints, num_neighbors) num_instances, num_keypoints, num_neighbors)
batch_indices.append( batch_indices.append(
......
...@@ -1863,7 +1863,7 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase): ...@@ -1863,7 +1863,7 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
class_id=1, class_id=1,
keypoint_indices=[0, 2], keypoint_indices=[0, 2],
peak_radius=1, peak_radius=1,
per_keypoint_offset=True) per_keypoint_depth=True)
(indices, depths, weights) = cn_assigner.assign_keypoints_depth_targets( (indices, depths, weights) = cn_assigner.assign_keypoints_depth_targets(
height=120, height=120,
width=80, width=80,
......
...@@ -2430,7 +2430,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2430,7 +2430,8 @@ class CenterNetMetaArch(model.DetectionModel):
keypoint_std_dev=kp_params.keypoint_std_dev, keypoint_std_dev=kp_params.keypoint_std_dev,
peak_radius=kp_params.offset_peak_radius, peak_radius=kp_params.offset_peak_radius,
per_keypoint_offset=kp_params.per_keypoint_offset, per_keypoint_offset=kp_params.per_keypoint_offset,
compute_heatmap_sparse=self._compute_heatmap_sparse)) compute_heatmap_sparse=self._compute_heatmap_sparse,
per_keypoint_depth=kp_params.per_keypoint_depth))
if self._mask_params is not None: if self._mask_params is not None:
target_assigners[SEGMENTATION_TASK] = ( target_assigners[SEGMENTATION_TASK] = (
cn_assigner.CenterNetMaskTargetAssigner(stride)) cn_assigner.CenterNetMaskTargetAssigner(stride))
...@@ -2853,17 +2854,13 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2853,17 +2854,13 @@ class CenterNetMetaArch(model.DetectionModel):
gt_keypoint_depths_list=gt_keypoint_depths_list, gt_keypoint_depths_list=gt_keypoint_depths_list,
gt_keypoint_depth_weights_list=gt_keypoint_depth_weights_list) gt_keypoint_depth_weights_list=gt_keypoint_depth_weights_list)
if kp_params.per_keypoint_offset and not kp_params.per_keypoint_depth:
batch_indices = batch_indices[:, 0:3]
# Keypoint offset loss. # Keypoint offset loss.
loss = 0.0 loss = 0.0
for prediction in depth_predictions: for prediction in depth_predictions:
# TODO(yuhuic): Update this function to use if kp_params.per_keypoint_depth:
# cn_assigner.get_batch_predictions_from_indices(). prediction = tf.expand_dims(prediction, axis=-1)
selected_depths = tf.gather_nd(prediction, batch_indices) selected_depths = cn_assigner.get_batch_predictions_from_indices(
if kp_params.per_keypoint_offset and kp_params.per_keypoint_depth: prediction, batch_indices)
selected_depths = tf.expand_dims(selected_depths, axis=-1)
# The dimensions passed are not as per the doc string but the loss # The dimensions passed are not as per the doc string but the loss
# still computes the correct value. # still computes the correct value.
unweighted_loss = localization_loss_fn( unweighted_loss = localization_loss_fn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment