Commit a905826c authored by Yu-hui Chen, committed by TF Object Detection Team
Browse files

Cleaned up the target_assigner and CenterNet meta arch so that the depth
logic doesn't depend on the per_keypoint_offset value and uses
get_batch_predictions_from_indices to be consistent with other prediction
heads.

PiperOrigin-RevId: 369925004
parent d54edbcf
......@@ -1318,7 +1318,8 @@ class CenterNetKeypointTargetAssigner(object):
keypoint_std_dev=None,
per_keypoint_offset=False,
peak_radius=0,
compute_heatmap_sparse=False):
compute_heatmap_sparse=False,
per_keypoint_depth=False):
"""Initializes a CenterNet keypoints target assigner.
Args:
......@@ -1349,12 +1350,16 @@ class CenterNetKeypointTargetAssigner(object):
version of the Op that computes the heatmap. The sparse version scales
better with number of keypoint types, but in some cases is known to
cause an OOM error. See (b/170989061).
per_keypoint_depth: A bool indicating whether the model predicts the depth
of each keypoint in independent channels. Similar to
per_keypoint_offset but for the keypoint depth.
"""
self._stride = stride
self._class_id = class_id
self._keypoint_indices = keypoint_indices
self._per_keypoint_offset = per_keypoint_offset
self._per_keypoint_depth = per_keypoint_depth
self._peak_radius = peak_radius
self._compute_heatmap_sparse = compute_heatmap_sparse
if keypoint_std_dev is None:
......@@ -1686,14 +1691,15 @@ class CenterNetKeypointTargetAssigner(object):
Returns:
batch_indices: an integer tensor of shape [num_total_instances, 3] (or
[num_total_instances, 4] if 'per_keypoint_offset' is set True) holding
[num_total_instances, 4] if 'per_keypoint_depth' is set True) holding
the indices inside the predicted tensor which should be penalized. The
first column indicates the index along the batch dimension and the
second and third columns indicate the index along the y and x
dimensions respectively. The fourth column corresponds to the channel
dimension (if 'per_keypoint_depth' is set True).
batch_depths: a float tensor of shape [num_total_instances, 1] indicating
the target depth of each keypoint.
batch_depths: a float tensor of shape [num_total_instances, 1] (or
[num_total_instances, num_keypoints] if per_keypoint_depth is set True)
indicating the target depth of each keypoint.
batch_weights: a float tensor of shape [num_total_instances] indicating
the weight of each prediction.
Note that num_total_instances = batch_size * num_instances *
......@@ -1800,7 +1806,7 @@ class CenterNetKeypointTargetAssigner(object):
# Prepare the batch indices to be prepended.
batch_index = tf.fill(
[num_instances * num_keypoints * num_neighbors, 1], i)
if self._per_keypoint_offset:
if self._per_keypoint_depth:
tiled_keypoint_types = self._get_keypoint_types(
num_instances, num_keypoints, num_neighbors)
batch_indices.append(
......
......@@ -1863,7 +1863,7 @@ class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
class_id=1,
keypoint_indices=[0, 2],
peak_radius=1,
per_keypoint_offset=True)
per_keypoint_depth=True)
(indices, depths, weights) = cn_assigner.assign_keypoints_depth_targets(
height=120,
width=80,
......
......@@ -2430,7 +2430,8 @@ class CenterNetMetaArch(model.DetectionModel):
keypoint_std_dev=kp_params.keypoint_std_dev,
peak_radius=kp_params.offset_peak_radius,
per_keypoint_offset=kp_params.per_keypoint_offset,
compute_heatmap_sparse=self._compute_heatmap_sparse))
compute_heatmap_sparse=self._compute_heatmap_sparse,
per_keypoint_depth=kp_params.per_keypoint_depth))
if self._mask_params is not None:
target_assigners[SEGMENTATION_TASK] = (
cn_assigner.CenterNetMaskTargetAssigner(stride))
......@@ -2853,17 +2854,13 @@ class CenterNetMetaArch(model.DetectionModel):
gt_keypoint_depths_list=gt_keypoint_depths_list,
gt_keypoint_depth_weights_list=gt_keypoint_depth_weights_list)
if kp_params.per_keypoint_offset and not kp_params.per_keypoint_depth:
batch_indices = batch_indices[:, 0:3]
# Keypoint depth loss.
loss = 0.0
for prediction in depth_predictions:
# TODO(yuhuic): Update this function to use
# cn_assigner.get_batch_predictions_from_indices().
selected_depths = tf.gather_nd(prediction, batch_indices)
if kp_params.per_keypoint_offset and kp_params.per_keypoint_depth:
selected_depths = tf.expand_dims(selected_depths, axis=-1)
if kp_params.per_keypoint_depth:
prediction = tf.expand_dims(prediction, axis=-1)
selected_depths = cn_assigner.get_batch_predictions_from_indices(
prediction, batch_indices)
# The dimensions passed are not as per the doc string but the loss
# still computes the correct value.
unweighted_loss = localization_loss_fn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment