Commit 0c85c06c authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Extended CenterNet model to predict keypoint depth information.

PiperOrigin-RevId: 359344675
parent 3cfd0ba0
...@@ -868,7 +868,10 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict): ...@@ -868,7 +868,10 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
candidate_search_scale=kp_config.candidate_search_scale, candidate_search_scale=kp_config.candidate_search_scale,
candidate_ranking_mode=kp_config.candidate_ranking_mode, candidate_ranking_mode=kp_config.candidate_ranking_mode,
offset_peak_radius=kp_config.offset_peak_radius, offset_peak_radius=kp_config.offset_peak_radius,
per_keypoint_offset=kp_config.per_keypoint_offset) per_keypoint_offset=kp_config.per_keypoint_offset,
predict_depth=kp_config.predict_depth,
per_keypoint_depth=kp_config.per_keypoint_depth,
keypoint_depth_loss_weight=kp_config.keypoint_depth_loss_weight)
def object_detection_proto_to_params(od_config): def object_detection_proto_to_params(od_config):
......
...@@ -116,6 +116,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -116,6 +116,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
candidate_ranking_mode: "score_distance_ratio" candidate_ranking_mode: "score_distance_ratio"
offset_peak_radius: 3 offset_peak_radius: 3
per_keypoint_offset: true per_keypoint_offset: true
predict_depth: true
per_keypoint_depth: true
keypoint_depth_loss_weight: 0.3
""" """
config = text_format.Merge(task_proto_txt, config = text_format.Merge(task_proto_txt,
center_net_pb2.CenterNet.KeypointEstimation()) center_net_pb2.CenterNet.KeypointEstimation())
...@@ -264,6 +267,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest): ...@@ -264,6 +267,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self.assertEqual(kp_params.candidate_ranking_mode, 'score_distance_ratio') self.assertEqual(kp_params.candidate_ranking_mode, 'score_distance_ratio')
self.assertEqual(kp_params.offset_peak_radius, 3) self.assertEqual(kp_params.offset_peak_radius, 3)
self.assertEqual(kp_params.per_keypoint_offset, True) self.assertEqual(kp_params.per_keypoint_offset, True)
self.assertEqual(kp_params.predict_depth, True)
self.assertEqual(kp_params.per_keypoint_depth, True)
self.assertAlmostEqual(kp_params.keypoint_depth_loss_weight, 0.3)
# Check mask related parameters. # Check mask related parameters.
self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7) self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7)
......
...@@ -423,12 +423,12 @@ def prediction_tensors_to_temporal_offsets( ...@@ -423,12 +423,12 @@ def prediction_tensors_to_temporal_offsets(
return offsets return offsets
def prediction_tensors_to_keypoint_candidates( def prediction_tensors_to_keypoint_candidates(keypoint_heatmap_predictions,
keypoint_heatmap_predictions, keypoint_heatmap_offsets,
keypoint_heatmap_offsets, keypoint_score_threshold=0.1,
keypoint_score_threshold=0.1, max_pool_kernel_size=1,
max_pool_kernel_size=1, max_candidates=20,
max_candidates=20): keypoint_depths=None):
"""Convert keypoint heatmap predictions and offsets to keypoint candidates. """Convert keypoint heatmap predictions and offsets to keypoint candidates.
Args: Args:
...@@ -437,14 +437,17 @@ def prediction_tensors_to_keypoint_candidates( ...@@ -437,14 +437,17 @@ def prediction_tensors_to_keypoint_candidates(
keypoint_heatmap_offsets: A float tensor of shape [batch_size, height, keypoint_heatmap_offsets: A float tensor of shape [batch_size, height,
width, 2] (or [batch_size, height, width, 2 * num_keypoints] if width, 2] (or [batch_size, height, width, 2 * num_keypoints] if
'per_keypoint_offset' is set True) representing the per-keypoint offsets. 'per_keypoint_offset' is set True) representing the per-keypoint offsets.
keypoint_score_threshold: float, the threshold for considering a keypoint keypoint_score_threshold: float, the threshold for considering a keypoint a
a candidate. candidate.
max_pool_kernel_size: integer, the max pool kernel size to use to pull off max_pool_kernel_size: integer, the max pool kernel size to use to pull off
peak score locations in a neighborhood. For example, to make sure no two peak score locations in a neighborhood. For example, to make sure no two
neighboring values for the same keypoint are returned, set neighboring values for the same keypoint are returned, set
max_pool_kernel_size=3. If None or 1, will not apply any local filtering. max_pool_kernel_size=3. If None or 1, will not apply any local filtering.
max_candidates: integer, maximum number of keypoint candidates per max_candidates: integer, maximum number of keypoint candidates per keypoint
keypoint type. type.
keypoint_depths: (optional) A float tensor of shape [batch_size, height,
width, 1] (or [batch_size, height, width, num_keypoints] if
'per_keypoint_depth' is set True) representing the per-keypoint depths.
Returns: Returns:
keypoint_candidates: A tensor of shape keypoint_candidates: A tensor of shape
...@@ -458,6 +461,9 @@ def prediction_tensors_to_keypoint_candidates( ...@@ -458,6 +461,9 @@ def prediction_tensors_to_keypoint_candidates(
[batch_size, num_keypoints] with the number of candidates for each [batch_size, num_keypoints] with the number of candidates for each
keypoint type, as it's possible to filter some candidates due to the score keypoint type, as it's possible to filter some candidates due to the score
threshold. threshold.
depth_candidates: A tensor of shape [batch_size, max_candidates,
num_keypoints] representing the estimated depth of each keypoint
candidate. Return None if the input keypoint_depths is None.
""" """
batch_size, _, _, num_keypoints = _get_shape(keypoint_heatmap_predictions, 4) batch_size, _, _, num_keypoints = _get_shape(keypoint_heatmap_predictions, 4)
# Get x, y and channel indices corresponding to the top indices in the # Get x, y and channel indices corresponding to the top indices in the
...@@ -499,13 +505,13 @@ def prediction_tensors_to_keypoint_candidates( ...@@ -499,13 +505,13 @@ def prediction_tensors_to_keypoint_candidates(
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that. In this # tf_gather_nd instead and here we prepare the indices for that. In this
# case, channel_indices indicates which keypoint to use the offset from. # case, channel_indices indicates which keypoint to use the offset from.
combined_indices = tf.stack([ channel_combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_indices), _multi_range(batch_size, value_repetitions=num_indices),
_multi_range(num_indices, range_repetitions=batch_size), _multi_range(num_indices, range_repetitions=batch_size),
tf.reshape(channel_indices, [-1]) tf.reshape(channel_indices, [-1])
], axis=1) ], axis=1)
offsets = tf.gather_nd(reshaped_offsets, combined_indices) offsets = tf.gather_nd(reshaped_offsets, channel_combined_indices)
offsets = tf.reshape(offsets, [batch_size, num_indices, -1]) offsets = tf.reshape(offsets, [batch_size, num_indices, -1])
else: else:
offsets = selected_offsets offsets = selected_offsets
...@@ -524,14 +530,38 @@ def prediction_tensors_to_keypoint_candidates( ...@@ -524,14 +530,38 @@ def prediction_tensors_to_keypoint_candidates(
num_candidates = tf.reduce_sum( num_candidates = tf.reduce_sum(
tf.to_int32(keypoint_scores >= keypoint_score_threshold), axis=1) tf.to_int32(keypoint_scores >= keypoint_score_threshold), axis=1)
return keypoint_candidates, keypoint_scores, num_candidates depth_candidates = None
if keypoint_depths is not None:
selected_depth_flat = tf.gather_nd(keypoint_depths, combined_indices)
selected_depth = tf.reshape(selected_depth_flat,
[batch_size, num_indices, -1])
_, _, num_depth_channels = _get_shape(selected_depth, 3)
if num_depth_channels > 1:
combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_indices),
_multi_range(num_indices, range_repetitions=batch_size),
tf.reshape(channel_indices, [-1])
], axis=1)
depth = tf.gather_nd(selected_depth, combined_indices)
depth = tf.reshape(depth, [batch_size, num_indices, -1])
else:
depth = selected_depth
depth_candidates = tf.reshape(depth,
[batch_size, num_keypoints, max_candidates])
depth_candidates = tf.transpose(depth_candidates, [0, 2, 1])
return keypoint_candidates, keypoint_scores, num_candidates, depth_candidates
def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
def prediction_to_single_instance_keypoints(object_heatmap,
keypoint_heatmap,
keypoint_offset, keypoint_offset,
keypoint_regression, stride, keypoint_regression,
stride,
object_center_std_dev, object_center_std_dev,
keypoint_std_dev, kp_params): keypoint_std_dev,
kp_params,
keypoint_depths=None):
"""Postprocess function to predict single instance keypoints. """Postprocess function to predict single instance keypoints.
This is a simplified postprocessing function based on the assumption that This is a simplified postprocessing function based on the assumption that
...@@ -569,6 +599,9 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap, ...@@ -569,6 +599,9 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
representing the standard deviation corresponding to each joint. representing the standard deviation corresponding to each joint.
kp_params: A `KeypointEstimationParams` object with parameters for a single kp_params: A `KeypointEstimationParams` object with parameters for a single
keypoint class. keypoint class.
keypoint_depths: (optional) A float tensor of shape [batch_size, height,
width, 1] (or [batch_size, height, width, num_keypoints] if
'per_keypoint_depth' is set True) representing the per-keypoint depths.
Returns: Returns:
A tuple of two tensors: A tuple of two tensors:
...@@ -577,6 +610,9 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap, ...@@ -577,6 +610,9 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
map space. map space.
keypoint_scores: A float tensor with shape [1, 1, num_keypoints] keypoint_scores: A float tensor with shape [1, 1, num_keypoints]
representing the keypoint prediction scores. representing the keypoint prediction scores.
keypoint_depths: A float tensor with shape [1, 1, num_keypoints]
representing the estimated keypoint depths. Return None if the input
keypoint_depths is None.
Raises: Raises:
ValueError: if the input keypoint_std_dev doesn't have valid number of ValueError: if the input keypoint_std_dev doesn't have valid number of
...@@ -636,14 +672,16 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap, ...@@ -636,14 +672,16 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
# Get the keypoint locations/scores: # Get the keypoint locations/scores:
# keypoint_candidates: [1, 1, num_keypoints, 2] # keypoint_candidates: [1, 1, num_keypoints, 2]
# keypoint_scores: [1, 1, num_keypoints] # keypoint_scores: [1, 1, num_keypoints]
(keypoint_candidates, keypoint_scores, # depth_candidates: [1, 1, num_keypoints]
_) = prediction_tensors_to_keypoint_candidates( (keypoint_candidates, keypoint_scores, _,
depth_candidates) = prediction_tensors_to_keypoint_candidates(
keypoint_predictions, keypoint_predictions,
keypoint_offset, keypoint_offset,
keypoint_score_threshold=kp_params.keypoint_candidate_score_threshold, keypoint_score_threshold=kp_params.keypoint_candidate_score_threshold,
max_pool_kernel_size=kp_params.peak_max_pool_kernel_size, max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
max_candidates=1) max_candidates=1,
return keypoint_candidates, keypoint_scores keypoint_depths=keypoint_depths)
return keypoint_candidates, keypoint_scores, depth_candidates
def regressed_keypoints_at_object_centers(regressed_keypoint_predictions, def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
...@@ -697,11 +735,16 @@ def regressed_keypoints_at_object_centers(regressed_keypoint_predictions, ...@@ -697,11 +735,16 @@ def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
[batch_size, num_instances, -1]) [batch_size, num_instances, -1])
def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores, def refine_keypoints(regressed_keypoints,
num_keypoint_candidates, bboxes=None, keypoint_candidates,
unmatched_keypoint_score=0.1, box_scale=1.2, keypoint_scores,
num_keypoint_candidates,
bboxes=None,
unmatched_keypoint_score=0.1,
box_scale=1.2,
candidate_search_scale=0.3, candidate_search_scale=0.3,
candidate_ranking_mode='min_distance'): candidate_ranking_mode='min_distance',
keypoint_depth_candidates=None):
"""Refines regressed keypoints by snapping to the nearest candidate keypoints. """Refines regressed keypoints by snapping to the nearest candidate keypoints.
The initial regressed keypoints represent a full set of keypoints regressed The initial regressed keypoints represent a full set of keypoints regressed
...@@ -757,6 +800,9 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores, ...@@ -757,6 +800,9 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
candidate_ranking_mode: A string as one of ['min_distance', candidate_ranking_mode: A string as one of ['min_distance',
'score_distance_ratio'] indicating how to select the candidate. If invalid 'score_distance_ratio'] indicating how to select the candidate. If invalid
value is provided, an ValueError will be raised. value is provided, an ValueError will be raised.
keypoint_depth_candidates: (optional) A float tensor of shape
[batch_size, max_candidates, num_keypoints] indicating the depths for
keypoint candidates.
Returns: Returns:
A tuple with: A tuple with:
...@@ -836,9 +882,11 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores, ...@@ -836,9 +882,11 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
# Gather the coordinates and scores corresponding to the closest candidates. # Gather the coordinates and scores corresponding to the closest candidates.
# Shape of tensors are [batch_size, num_instances, num_keypoints, 2] and # Shape of tensors are [batch_size, num_instances, num_keypoints, 2] and
# [batch_size, num_instances, num_keypoints], respectively. # [batch_size, num_instances, num_keypoints], respectively.
nearby_candidate_coords, nearby_candidate_scores = ( (nearby_candidate_coords, nearby_candidate_scores,
_gather_candidates_at_indices(keypoint_candidates, keypoint_scores, nearby_candidate_depths) = (
nearby_candidate_inds)) _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
nearby_candidate_inds,
keypoint_depth_candidates))
if bboxes is None: if bboxes is None:
# Create bboxes from regressed keypoints. # Create bboxes from regressed keypoints.
...@@ -895,7 +943,12 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores, ...@@ -895,7 +943,12 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
unmatched_keypoint_score * tf.ones_like(nearby_candidate_scores), unmatched_keypoint_score * tf.ones_like(nearby_candidate_scores),
nearby_candidate_scores) nearby_candidate_scores)
return refined_keypoints, refined_scores refined_depths = None
if nearby_candidate_depths is not None:
refined_depths = tf.where(mask, tf.zeros_like(nearby_candidate_depths),
nearby_candidate_depths)
return refined_keypoints, refined_scores, refined_depths
def _pad_to_full_keypoint_dim(keypoint_coords, keypoint_scores, keypoint_inds, def _pad_to_full_keypoint_dim(keypoint_coords, keypoint_scores, keypoint_inds,
...@@ -976,8 +1029,10 @@ def _pad_to_full_instance_dim(keypoint_coords, keypoint_scores, instance_inds, ...@@ -976,8 +1029,10 @@ def _pad_to_full_instance_dim(keypoint_coords, keypoint_scores, instance_inds,
return keypoint_coords_padded, keypoint_scores_padded return keypoint_coords_padded, keypoint_scores_padded
def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores, def _gather_candidates_at_indices(keypoint_candidates,
indices): keypoint_scores,
indices,
keypoint_depth_candidates=None):
"""Gathers keypoint candidate coordinates and scores at indices. """Gathers keypoint candidate coordinates and scores at indices.
Args: Args:
...@@ -987,13 +1042,18 @@ def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores, ...@@ -987,13 +1042,18 @@ def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
num_keypoints] with keypoint scores. num_keypoints] with keypoint scores.
indices: an integer tensor of shape [batch_size, num_indices, num_keypoints] indices: an integer tensor of shape [batch_size, num_indices, num_keypoints]
with indices. with indices.
keypoint_depth_candidates: (optional) a float tensor of shape [batch_size,
max_candidates, num_keypoints] with keypoint depths.
Returns: Returns:
A tuple with A tuple with
gathered_keypoint_candidates: a float tensor of shape [batch_size, gathered_keypoint_candidates: a float tensor of shape [batch_size,
num_indices, num_keypoints, 2] with gathered coordinates. num_indices, num_keypoints, 2] with gathered coordinates.
gathered_keypoint_scores: a float tensor of shape [batch_size, gathered_keypoint_scores: a float tensor of shape [batch_size,
num_indices, num_keypoints, 2]. num_indices, num_keypoints].
gathered_keypoint_depths: a float tensor of shape [batch_size,
num_indices, num_keypoints]. Return None if the input
keypoint_depth_candidates is None.
""" """
batch_size, num_indices, num_keypoints = _get_shape(indices, 3) batch_size, num_indices, num_keypoints = _get_shape(indices, 3)
...@@ -1035,7 +1095,19 @@ def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores, ...@@ -1035,7 +1095,19 @@ def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
gathered_keypoint_scores = tf.transpose(nearby_candidate_scores_transposed, gathered_keypoint_scores = tf.transpose(nearby_candidate_scores_transposed,
[0, 2, 1]) [0, 2, 1])
return gathered_keypoint_candidates, gathered_keypoint_scores gathered_keypoint_depths = None
if keypoint_depth_candidates is not None:
keypoint_depths_transposed = tf.transpose(keypoint_depth_candidates,
[0, 2, 1])
nearby_candidate_depths_transposed = tf.gather_nd(
keypoint_depths_transposed, combined_indices)
nearby_candidate_depths_transposed = tf.reshape(
nearby_candidate_depths_transposed,
[batch_size, num_keypoints, num_indices])
gathered_keypoint_depths = tf.transpose(nearby_candidate_depths_transposed,
[0, 2, 1])
return (gathered_keypoint_candidates, gathered_keypoint_scores,
gathered_keypoint_depths)
def flattened_indices_from_row_col_indices(row_indices, col_indices, num_cols): def flattened_indices_from_row_col_indices(row_indices, col_indices, num_cols):
...@@ -1517,7 +1589,8 @@ class KeypointEstimationParams( ...@@ -1517,7 +1589,8 @@ class KeypointEstimationParams(
'heatmap_bias_init', 'num_candidates_per_keypoint', 'task_loss_weight', 'heatmap_bias_init', 'num_candidates_per_keypoint', 'task_loss_weight',
'peak_max_pool_kernel_size', 'unmatched_keypoint_score', 'box_scale', 'peak_max_pool_kernel_size', 'unmatched_keypoint_score', 'box_scale',
'candidate_search_scale', 'candidate_ranking_mode', 'candidate_search_scale', 'candidate_ranking_mode',
'offset_peak_radius', 'per_keypoint_offset' 'offset_peak_radius', 'per_keypoint_offset', 'predict_depth',
'per_keypoint_depth', 'keypoint_depth_loss_weight'
])): ])):
"""Namedtuple to host object detection related parameters. """Namedtuple to host object detection related parameters.
...@@ -1550,7 +1623,10 @@ class KeypointEstimationParams( ...@@ -1550,7 +1623,10 @@ class KeypointEstimationParams(
candidate_search_scale=0.3, candidate_search_scale=0.3,
candidate_ranking_mode='min_distance', candidate_ranking_mode='min_distance',
offset_peak_radius=0, offset_peak_radius=0,
per_keypoint_offset=False): per_keypoint_offset=False,
predict_depth=False,
per_keypoint_depth=False,
keypoint_depth_loss_weight=1.0):
"""Constructor with default values for KeypointEstimationParams. """Constructor with default values for KeypointEstimationParams.
Args: Args:
...@@ -1614,6 +1690,12 @@ class KeypointEstimationParams( ...@@ -1614,6 +1690,12 @@ class KeypointEstimationParams(
original paper). If set True, the output offset target has the shape original paper). If set True, the output offset target has the shape
[batch_size, out_height, out_width, 2 * num_keypoints] (recommended when [batch_size, out_height, out_width, 2 * num_keypoints] (recommended when
the offset_peak_radius is not zero). the offset_peak_radius is not zero).
predict_depth: A bool indicates whether to predict the depth of each
keypoints.
per_keypoint_depth: A bool indicates whether the model predicts the depth
of each keypoints in independent channels. Similar to
per_keypoint_offset but for the keypoint depth.
keypoint_depth_loss_weight: The weight of the keypoint depth loss.
Returns: Returns:
An initialized KeypointEstimationParams namedtuple. An initialized KeypointEstimationParams namedtuple.
...@@ -1626,7 +1708,8 @@ class KeypointEstimationParams( ...@@ -1626,7 +1708,8 @@ class KeypointEstimationParams(
heatmap_bias_init, num_candidates_per_keypoint, task_loss_weight, heatmap_bias_init, num_candidates_per_keypoint, task_loss_weight,
peak_max_pool_kernel_size, unmatched_keypoint_score, box_scale, peak_max_pool_kernel_size, unmatched_keypoint_score, box_scale,
candidate_search_scale, candidate_ranking_mode, offset_peak_radius, candidate_search_scale, candidate_ranking_mode, offset_peak_radius,
per_keypoint_offset) per_keypoint_offset, predict_depth, per_keypoint_depth,
keypoint_depth_loss_weight)
class ObjectCenterParams( class ObjectCenterParams(
...@@ -1839,6 +1922,7 @@ BOX_OFFSET = 'box/offset' ...@@ -1839,6 +1922,7 @@ BOX_OFFSET = 'box/offset'
KEYPOINT_REGRESSION = 'keypoint/regression' KEYPOINT_REGRESSION = 'keypoint/regression'
KEYPOINT_HEATMAP = 'keypoint/heatmap' KEYPOINT_HEATMAP = 'keypoint/heatmap'
KEYPOINT_OFFSET = 'keypoint/offset' KEYPOINT_OFFSET = 'keypoint/offset'
KEYPOINT_DEPTH = 'keypoint/depth'
SEGMENTATION_TASK = 'segmentation_task' SEGMENTATION_TASK = 'segmentation_task'
SEGMENTATION_HEATMAP = 'segmentation/heatmap' SEGMENTATION_HEATMAP = 'segmentation/heatmap'
DENSEPOSE_TASK = 'densepose_task' DENSEPOSE_TASK = 'densepose_task'
...@@ -2055,6 +2139,15 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2055,6 +2139,15 @@ class CenterNetMetaArch(model.DetectionModel):
use_depthwise=self._use_depthwise) use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs) for _ in range(num_feature_outputs)
] ]
if kp_params.predict_depth:
num_depth_channel = (
num_keypoints if kp_params.per_keypoint_depth else 1)
prediction_heads[get_keypoint_name(task_name, KEYPOINT_DEPTH)] = [
make_prediction_net(
num_depth_channel, use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
# pylint: disable=g-complex-comprehension # pylint: disable=g-complex-comprehension
if self._mask_params is not None: if self._mask_params is not None:
prediction_heads[SEGMENTATION_HEATMAP] = [ prediction_heads[SEGMENTATION_HEATMAP] = [
...@@ -2305,6 +2398,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2305,6 +2398,7 @@ class CenterNetMetaArch(model.DetectionModel):
heatmap_key = get_keypoint_name(task_name, KEYPOINT_HEATMAP) heatmap_key = get_keypoint_name(task_name, KEYPOINT_HEATMAP)
offset_key = get_keypoint_name(task_name, KEYPOINT_OFFSET) offset_key = get_keypoint_name(task_name, KEYPOINT_OFFSET)
regression_key = get_keypoint_name(task_name, KEYPOINT_REGRESSION) regression_key = get_keypoint_name(task_name, KEYPOINT_REGRESSION)
depth_key = get_keypoint_name(task_name, KEYPOINT_DEPTH)
heatmap_loss = self._compute_kp_heatmap_loss( heatmap_loss = self._compute_kp_heatmap_loss(
input_height=input_height, input_height=input_height,
input_width=input_width, input_width=input_width,
...@@ -2332,6 +2426,14 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2332,6 +2426,14 @@ class CenterNetMetaArch(model.DetectionModel):
kp_params.keypoint_offset_loss_weight * offset_loss) kp_params.keypoint_offset_loss_weight * offset_loss)
loss_dict[regression_key] = ( loss_dict[regression_key] = (
kp_params.keypoint_regression_loss_weight * reg_loss) kp_params.keypoint_regression_loss_weight * reg_loss)
if kp_params.predict_depth:
depth_loss = self._compute_kp_depth_loss(
input_height=input_height,
input_width=input_width,
task_name=task_name,
depth_predictions=prediction_dict[depth_key],
localization_loss_fn=kp_params.localization_loss)
loss_dict[depth_key] = kp_params.keypoint_depth_loss_weight * depth_loss
return loss_dict return loss_dict
def _compute_kp_heatmap_loss(self, input_height, input_width, task_name, def _compute_kp_heatmap_loss(self, input_height, input_width, task_name,
...@@ -2501,6 +2603,68 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2501,6 +2603,68 @@ class CenterNetMetaArch(model.DetectionModel):
tf.maximum(tf.reduce_sum(batch_weights), 1.0)) tf.maximum(tf.reduce_sum(batch_weights), 1.0))
return loss return loss
def _compute_kp_depth_loss(self, input_height, input_width, task_name,
depth_predictions, localization_loss_fn):
"""Computes the loss of the keypoint depth estimation.
Args:
input_height: An integer scalar tensor representing input image height.
input_width: An integer scalar tensor representing input image width.
task_name: A string representing the name of the keypoint task.
depth_predictions: A list of float tensors of shape [batch_size,
out_height, out_width, 1 (or num_keypoints)] representing the prediction
heads of the model for keypoint depth.
localization_loss_fn: An object_detection.core.losses.Loss object to
compute the loss for the keypoint offset predictions in CenterNet.
Returns:
loss: A float scalar tensor representing the keypoint depth loss
normalized by number of total keypoints.
"""
kp_params = self._kp_params_dict[task_name]
gt_keypoints_list = self.groundtruth_lists(fields.BoxListFields.keypoints)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
gt_keypoint_depths_list = self.groundtruth_lists(
fields.BoxListFields.keypoint_depths)
gt_keypoint_depth_weights_list = self.groundtruth_lists(
fields.BoxListFields.keypoint_depth_weights)
assigner = self._target_assigner_dict[task_name]
(batch_indices, batch_depths,
batch_weights) = assigner.assign_keypoints_depth_targets(
height=input_height,
width=input_width,
gt_keypoints_list=gt_keypoints_list,
gt_weights_list=gt_weights_list,
gt_classes_list=gt_classes_list,
gt_keypoint_depths_list=gt_keypoint_depths_list,
gt_keypoint_depth_weights_list=gt_keypoint_depth_weights_list)
if kp_params.per_keypoint_offset and not kp_params.per_keypoint_depth:
batch_indices = batch_indices[:, 0:3]
# Keypoint offset loss.
loss = 0.0
for prediction in depth_predictions:
selected_depths = cn_assigner.get_batch_predictions_from_indices(
prediction, batch_indices)
if kp_params.per_keypoint_offset and kp_params.per_keypoint_depth:
selected_depths = tf.expand_dims(selected_depths, axis=-1)
# The dimensions passed are not as per the doc string but the loss
# still computes the correct value.
unweighted_loss = localization_loss_fn(
selected_depths,
batch_depths,
weights=tf.expand_dims(tf.ones_like(batch_weights), -1))
# Apply the weights after the loss function to have full control over it.
loss += batch_weights * tf.squeeze(unweighted_loss, axis=1)
loss = tf.reduce_sum(loss) / (
float(len(depth_predictions)) *
tf.maximum(tf.reduce_sum(batch_weights), 1.0))
return loss
def _compute_segmentation_losses(self, prediction_dict, per_pixel_weights): def _compute_segmentation_losses(self, prediction_dict, per_pixel_weights):
"""Computes all the losses associated with segmentation. """Computes all the losses associated with segmentation.
...@@ -3051,9 +3215,10 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3051,9 +3215,10 @@ class CenterNetMetaArch(model.DetectionModel):
# keypoint, we fall back to a simpler postprocessing function which uses # keypoint, we fall back to a simpler postprocessing function which uses
# the ops that are supported by tf.lite on GPU. # the ops that are supported by tf.lite on GPU.
if len(self._kp_params_dict) == 1 and self._num_classes == 1: if len(self._kp_params_dict) == 1 and self._num_classes == 1:
keypoints, keypoint_scores = self._postprocess_keypoints_single_class( (keypoints, keypoint_scores,
prediction_dict, classes, y_indices, x_indices, keypoint_depths) = self._postprocess_keypoints_single_class(
boxes_strided, num_detections) prediction_dict, classes, y_indices, x_indices, boxes_strided,
num_detections)
# The map_fn used to clip out of frame keypoints creates issues when # The map_fn used to clip out of frame keypoints creates issues when
# converting to tf.lite model so we disable it and let the users to # converting to tf.lite model so we disable it and let the users to
# handle those out of frame keypoints. # handle those out of frame keypoints.
...@@ -3061,7 +3226,18 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3061,7 +3226,18 @@ class CenterNetMetaArch(model.DetectionModel):
convert_strided_predictions_to_normalized_keypoints( convert_strided_predictions_to_normalized_keypoints(
keypoints, keypoint_scores, self._stride, true_image_shapes, keypoints, keypoint_scores, self._stride, true_image_shapes,
clip_out_of_frame_keypoints=False)) clip_out_of_frame_keypoints=False))
if keypoint_depths is not None:
postprocess_dict.update({
fields.DetectionResultFields.detection_keypoint_depths:
keypoint_depths
})
else: else:
# Multi-class keypoint estimation task does not support depth
# estimation.
assert all([
not kp_dict.predict_depth
for kp_dict in self._kp_params_dict.values()
])
keypoints, keypoint_scores = self._postprocess_keypoints_multi_class( keypoints, keypoint_scores = self._postprocess_keypoints_multi_class(
prediction_dict, classes, y_indices, x_indices, prediction_dict, classes, y_indices, x_indices,
boxes_strided, num_detections) boxes_strided, num_detections)
...@@ -3200,7 +3376,11 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3200,7 +3376,11 @@ class CenterNetMetaArch(model.DetectionModel):
task_name, KEYPOINT_REGRESSION)][-1] task_name, KEYPOINT_REGRESSION)][-1]
object_heatmap = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1]) object_heatmap = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
keypoints, keypoint_scores = ( keypoint_depths = None
if kp_params.predict_depth:
keypoint_depths = prediction_dict[get_keypoint_name(
task_name, KEYPOINT_DEPTH)][-1]
keypoints, keypoint_scores, keypoint_depths = (
prediction_to_single_instance_keypoints( prediction_to_single_instance_keypoints(
object_heatmap=object_heatmap, object_heatmap=object_heatmap,
keypoint_heatmap=keypoint_heatmap, keypoint_heatmap=keypoint_heatmap,
...@@ -3209,7 +3389,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3209,7 +3389,8 @@ class CenterNetMetaArch(model.DetectionModel):
stride=self._stride, stride=self._stride,
object_center_std_dev=object_center_std_dev, object_center_std_dev=object_center_std_dev,
keypoint_std_dev=keypoint_std_dev, keypoint_std_dev=keypoint_std_dev,
kp_params=kp_params)) kp_params=kp_params,
keypoint_depths=keypoint_depths))
keypoints, keypoint_scores = ( keypoints, keypoint_scores = (
convert_strided_predictions_to_normalized_keypoints( convert_strided_predictions_to_normalized_keypoints(
...@@ -3222,6 +3403,12 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3222,6 +3403,12 @@ class CenterNetMetaArch(model.DetectionModel):
fields.DetectionResultFields.detection_keypoints: keypoints, fields.DetectionResultFields.detection_keypoints: keypoints,
fields.DetectionResultFields.detection_keypoint_scores: keypoint_scores fields.DetectionResultFields.detection_keypoint_scores: keypoint_scores
} }
if kp_params.predict_depth:
postprocess_dict.update({
fields.DetectionResultFields.detection_keypoint_depths:
keypoint_depths
})
return postprocess_dict return postprocess_dict
def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices): def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices):
...@@ -3316,7 +3503,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3316,7 +3503,7 @@ class CenterNetMetaArch(model.DetectionModel):
# [1, num_instances_i, num_keypoints_i], respectively. Note that # [1, num_instances_i, num_keypoints_i], respectively. Note that
# num_instances_i and num_keypoints_i refers to the number of # num_instances_i and num_keypoints_i refers to the number of
# instances and keypoints for class i, respectively. # instances and keypoints for class i, respectively.
kpt_coords_for_class, kpt_scores_for_class = ( (kpt_coords_for_class, kpt_scores_for_class, _) = (
self._postprocess_keypoints_for_class_and_image( self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap, keypoint_offsets, keypoint_regression, keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices_for_kpt_class, x_indices_for_kpt_class, classes, y_indices_for_kpt_class, x_indices_for_kpt_class,
...@@ -3426,21 +3613,35 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3426,21 +3613,35 @@ class CenterNetMetaArch(model.DetectionModel):
get_keypoint_name(task_name, KEYPOINT_OFFSET)][-1] get_keypoint_name(task_name, KEYPOINT_OFFSET)][-1]
keypoint_regression = prediction_dict[ keypoint_regression = prediction_dict[
get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1] get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
keypoint_depth_predictions = None
if kp_params.predict_depth:
keypoint_depth_predictions = prediction_dict[get_keypoint_name(
task_name, KEYPOINT_DEPTH)][-1]
batch_size, _, _ = _get_shape(boxes, 3) batch_size, _, _ = _get_shape(boxes, 3)
kpt_coords_for_example_list = [] kpt_coords_for_example_list = []
kpt_scores_for_example_list = [] kpt_scores_for_example_list = []
kpt_depths_for_example_list = []
for ex_ind in range(batch_size): for ex_ind in range(batch_size):
# Postprocess keypoints and scores for class and single image. Shapes # Postprocess keypoints and scores for class and single image. Shapes
# are [1, max_detections, num_keypoints, 2] and # are [1, max_detections, num_keypoints, 2] and
# [1, max_detections, num_keypoints], respectively. # [1, max_detections, num_keypoints], respectively.
kpt_coords_for_class, kpt_scores_for_class = ( (kpt_coords_for_class, kpt_scores_for_class, kpt_depths_for_class) = (
self._postprocess_keypoints_for_class_and_image( self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap, keypoint_offsets, keypoint_regression, classes, keypoint_heatmap,
y_indices, x_indices, boxes, ex_ind, kp_params)) keypoint_offsets,
keypoint_regression,
classes,
y_indices,
x_indices,
boxes,
ex_ind,
kp_params,
keypoint_depth_predictions=keypoint_depth_predictions))
kpt_coords_for_example_list.append(kpt_coords_for_class) kpt_coords_for_example_list.append(kpt_coords_for_class)
kpt_scores_for_example_list.append(kpt_scores_for_class) kpt_scores_for_example_list.append(kpt_scores_for_class)
kpt_depths_for_example_list.append(kpt_depths_for_class)
# Concatenate all keypoints and scores from all examples in the batch. # Concatenate all keypoints and scores from all examples in the batch.
# Shapes are [batch_size, max_detections, num_keypoints, 2] and # Shapes are [batch_size, max_detections, num_keypoints, 2] and
...@@ -3448,7 +3649,11 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3448,7 +3649,11 @@ class CenterNetMetaArch(model.DetectionModel):
keypoints = tf.concat(kpt_coords_for_example_list, axis=0) keypoints = tf.concat(kpt_coords_for_example_list, axis=0)
keypoint_scores = tf.concat(kpt_scores_for_example_list, axis=0) keypoint_scores = tf.concat(kpt_scores_for_example_list, axis=0)
return keypoints, keypoint_scores keypoint_depths = None
if kp_params.predict_depth:
keypoint_depths = tf.concat(kpt_depths_for_example_list, axis=0)
return keypoints, keypoint_scores, keypoint_depths
def _get_instance_indices(self, classes, num_detections, batch_index, def _get_instance_indices(self, classes, num_detections, batch_index,
class_id): class_id):
...@@ -3482,8 +3687,17 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3482,8 +3687,17 @@ class CenterNetMetaArch(model.DetectionModel):
return tf.cast(instance_inds, tf.int32) return tf.cast(instance_inds, tf.int32)
def _postprocess_keypoints_for_class_and_image( def _postprocess_keypoints_for_class_and_image(
self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes, self,
y_indices, x_indices, boxes, batch_index, kp_params): keypoint_heatmap,
keypoint_offsets,
keypoint_regression,
classes,
y_indices,
x_indices,
boxes,
batch_index,
kp_params,
keypoint_depth_predictions=None):
"""Postprocess keypoints for a single image and class. """Postprocess keypoints for a single image and class.
Args: Args:
...@@ -3504,6 +3718,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3504,6 +3718,8 @@ class CenterNetMetaArch(model.DetectionModel):
batch_index: An integer specifying the index for an example in the batch. batch_index: An integer specifying the index for an example in the batch.
kp_params: A `KeypointEstimationParams` object with parameters for a kp_params: A `KeypointEstimationParams` object with parameters for a
single keypoint class. single keypoint class.
keypoint_depth_predictions: (optional) A [batch_size, height, width, 1]
float32 tensor representing the keypoint depth prediction.
Returns: Returns:
A tuple of A tuple of
...@@ -3514,6 +3730,9 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3514,6 +3730,9 @@ class CenterNetMetaArch(model.DetectionModel):
for the specific class. for the specific class.
refined_scores: A [1, num_instances, num_keypoints] float32 tensor with refined_scores: A [1, num_instances, num_keypoints] float32 tensor with
keypoint scores. keypoint scores.
refined_depths: A [1, num_instances, num_keypoints] float32 tensor with
keypoint depths. Return None if the input keypoint_depth_predictions is
None.
""" """
num_keypoints = len(kp_params.keypoint_indices) num_keypoints = len(kp_params.keypoint_indices)
...@@ -3521,6 +3740,10 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3521,6 +3740,10 @@ class CenterNetMetaArch(model.DetectionModel):
keypoint_heatmap[batch_index:batch_index+1, ...]) keypoint_heatmap[batch_index:batch_index+1, ...])
keypoint_offsets = keypoint_offsets[batch_index:batch_index+1, ...] keypoint_offsets = keypoint_offsets[batch_index:batch_index+1, ...]
keypoint_regression = keypoint_regression[batch_index:batch_index+1, ...] keypoint_regression = keypoint_regression[batch_index:batch_index+1, ...]
keypoint_depths = None
if keypoint_depth_predictions is not None:
keypoint_depths = keypoint_depth_predictions[batch_index:batch_index + 1,
...]
y_indices = y_indices[batch_index:batch_index+1, ...] y_indices = y_indices[batch_index:batch_index+1, ...]
x_indices = x_indices[batch_index:batch_index+1, ...] x_indices = x_indices[batch_index:batch_index+1, ...]
boxes_slice = boxes[batch_index:batch_index+1, ...] boxes_slice = boxes[batch_index:batch_index+1, ...]
...@@ -3536,26 +3759,33 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3536,26 +3759,33 @@ class CenterNetMetaArch(model.DetectionModel):
# The shape of keypoint_candidates and keypoint_scores is: # The shape of keypoint_candidates and keypoint_scores is:
# [1, num_candidates_per_keypoint, num_keypoints, 2] and # [1, num_candidates_per_keypoint, num_keypoints, 2] and
# [1, num_candidates_per_keypoint, num_keypoints], respectively. # [1, num_candidates_per_keypoint, num_keypoints], respectively.
keypoint_candidates, keypoint_scores, num_keypoint_candidates = ( (keypoint_candidates, keypoint_scores, num_keypoint_candidates,
prediction_tensors_to_keypoint_candidates( keypoint_depth_candidates) = (
keypoint_heatmap, keypoint_offsets, prediction_tensors_to_keypoint_candidates(
keypoint_score_threshold=( keypoint_heatmap,
kp_params.keypoint_candidate_score_threshold), keypoint_offsets,
max_pool_kernel_size=kp_params.peak_max_pool_kernel_size, keypoint_score_threshold=(
max_candidates=kp_params.num_candidates_per_keypoint)) kp_params.keypoint_candidate_score_threshold),
max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
max_candidates=kp_params.num_candidates_per_keypoint,
keypoint_depths=keypoint_depths))
# Get the refined keypoints and scores, of shape # Get the refined keypoints and scores, of shape
# [1, num_instances, num_keypoints, 2] and # [1, num_instances, num_keypoints, 2] and
# [1, num_instances, num_keypoints], respectively. # [1, num_instances, num_keypoints], respectively.
refined_keypoints, refined_scores = refine_keypoints( (refined_keypoints, refined_scores, refined_depths) = refine_keypoints(
regressed_keypoints_for_objects, keypoint_candidates, keypoint_scores, regressed_keypoints_for_objects,
num_keypoint_candidates, bboxes=boxes_slice, keypoint_candidates,
keypoint_scores,
num_keypoint_candidates,
bboxes=boxes_slice,
unmatched_keypoint_score=kp_params.unmatched_keypoint_score, unmatched_keypoint_score=kp_params.unmatched_keypoint_score,
box_scale=kp_params.box_scale, box_scale=kp_params.box_scale,
candidate_search_scale=kp_params.candidate_search_scale, candidate_search_scale=kp_params.candidate_search_scale,
candidate_ranking_mode=kp_params.candidate_ranking_mode) candidate_ranking_mode=kp_params.candidate_ranking_mode,
keypoint_depth_candidates=keypoint_depth_candidates)
return refined_keypoints, refined_scores return refined_keypoints, refined_scores, refined_depths
def regularization_losses(self): def regularization_losses(self):
return [] return []
......
...@@ -695,7 +695,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -695,7 +695,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
keypoint_heatmap_offsets = tf.constant( keypoint_heatmap_offsets = tf.constant(
keypoint_heatmap_offsets_np, dtype=tf.float32) keypoint_heatmap_offsets_np, dtype=tf.float32)
keypoint_cands, keypoint_scores, num_keypoint_candidates = ( (keypoint_cands, keypoint_scores, num_keypoint_candidates, _) = (
cnma.prediction_tensors_to_keypoint_candidates( cnma.prediction_tensors_to_keypoint_candidates(
keypoint_heatmap, keypoint_heatmap,
keypoint_heatmap_offsets, keypoint_heatmap_offsets,
...@@ -780,7 +780,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -780,7 +780,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
keypoint_regression = tf.constant( keypoint_regression = tf.constant(
keypoint_regression_np, dtype=tf.float32) keypoint_regression_np, dtype=tf.float32)
(keypoint_cands, keypoint_scores) = ( (keypoint_cands, keypoint_scores, _) = (
cnma.prediction_to_single_instance_keypoints( cnma.prediction_to_single_instance_keypoints(
object_heatmap, object_heatmap,
keypoint_heatmap, keypoint_heatmap,
...@@ -839,7 +839,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -839,7 +839,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
keypoint_heatmap_offsets = tf.constant( keypoint_heatmap_offsets = tf.constant(
keypoint_heatmap_offsets_np, dtype=tf.float32) keypoint_heatmap_offsets_np, dtype=tf.float32)
keypoint_cands, keypoint_scores, num_keypoint_candidates = ( (keypoint_cands, keypoint_scores, num_keypoint_candidates, _) = (
cnma.prediction_tensors_to_keypoint_candidates( cnma.prediction_tensors_to_keypoint_candidates(
keypoint_heatmap, keypoint_heatmap,
keypoint_heatmap_offsets, keypoint_heatmap_offsets,
...@@ -880,6 +880,89 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -880,6 +880,89 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_array_equal(expected_num_keypoint_candidates, np.testing.assert_array_equal(expected_num_keypoint_candidates,
num_keypoint_candidates) num_keypoint_candidates)
@parameterized.parameters({'per_keypoint_depth': True},
{'per_keypoint_depth': False})
def test_keypoint_candidate_prediction_depth(self, per_keypoint_depth):
keypoint_heatmap_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_heatmap_np[0, 0, 0, 0] = 1.0
keypoint_heatmap_np[0, 2, 1, 0] = 0.7
keypoint_heatmap_np[0, 1, 1, 0] = 0.6
keypoint_heatmap_np[0, 0, 2, 1] = 0.7
keypoint_heatmap_np[0, 1, 1, 1] = 0.3 # Filtered by low score.
keypoint_heatmap_np[0, 2, 2, 1] = 0.2
keypoint_heatmap_np[1, 1, 0, 0] = 0.6
keypoint_heatmap_np[1, 2, 1, 0] = 0.5
keypoint_heatmap_np[1, 0, 0, 0] = 0.4
keypoint_heatmap_np[1, 0, 0, 1] = 1.0
keypoint_heatmap_np[1, 0, 1, 1] = 0.9
keypoint_heatmap_np[1, 2, 0, 1] = 0.8
if per_keypoint_depth:
keypoint_depths_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_depths_np[0, 0, 0, 0] = -1.5
keypoint_depths_np[0, 2, 1, 0] = -1.0
keypoint_depths_np[0, 0, 2, 1] = 1.5
else:
keypoint_depths_np = np.zeros((2, 3, 3, 1), dtype=np.float32)
keypoint_depths_np[0, 0, 0, 0] = -1.5
keypoint_depths_np[0, 2, 1, 0] = -1.0
keypoint_depths_np[0, 0, 2, 0] = 1.5
keypoint_heatmap_offsets_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_heatmap_offsets_np[0, 0, 0] = [0.5, 0.25]
keypoint_heatmap_offsets_np[0, 2, 1] = [-0.25, 0.5]
keypoint_heatmap_offsets_np[0, 1, 1] = [0.0, 0.0]
keypoint_heatmap_offsets_np[0, 0, 2] = [1.0, 0.0]
keypoint_heatmap_offsets_np[0, 2, 2] = [1.0, 1.0]
keypoint_heatmap_offsets_np[1, 1, 0] = [0.25, 0.5]
keypoint_heatmap_offsets_np[1, 2, 1] = [0.5, 0.0]
keypoint_heatmap_offsets_np[1, 0, 0] = [0.0, -0.5]
keypoint_heatmap_offsets_np[1, 0, 1] = [0.5, -0.5]
keypoint_heatmap_offsets_np[1, 2, 0] = [-1.0, -0.5]
def graph_fn():
keypoint_heatmap = tf.constant(keypoint_heatmap_np, dtype=tf.float32)
keypoint_heatmap_offsets = tf.constant(
keypoint_heatmap_offsets_np, dtype=tf.float32)
keypoint_depths = tf.constant(keypoint_depths_np, dtype=tf.float32)
(keypoint_cands, keypoint_scores, num_keypoint_candidates,
keypoint_depths) = (
cnma.prediction_tensors_to_keypoint_candidates(
keypoint_heatmap,
keypoint_heatmap_offsets,
keypoint_score_threshold=0.5,
max_pool_kernel_size=1,
max_candidates=2,
keypoint_depths=keypoint_depths))
return (keypoint_cands, keypoint_scores, num_keypoint_candidates,
keypoint_depths)
(_, keypoint_scores, _, keypoint_depths) = self.execute(graph_fn, [])
expected_keypoint_scores = [
[ # Example 0.
[1.0, 0.7], # Keypoint 1.
[0.7, 0.3], # Keypoint 2.
],
[ # Example 1.
[0.6, 1.0], # Keypoint 1.
[0.5, 0.9], # Keypoint 2.
],
]
expected_keypoint_depths = [
[
[-1.5, 1.5],
[-1.0, 0.0],
],
[
[0., 0.],
[0., 0.],
],
]
np.testing.assert_allclose(expected_keypoint_scores, keypoint_scores)
np.testing.assert_allclose(expected_keypoint_depths, keypoint_depths)
def test_regressed_keypoints_at_object_centers(self): def test_regressed_keypoints_at_object_centers(self):
batch_size = 2 batch_size = 2
num_keypoints = 5 num_keypoints = 5
...@@ -985,11 +1068,15 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -985,11 +1068,15 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32) keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32)
num_keypoint_candidates = tf.constant(num_keypoints_candidates_np, num_keypoint_candidates = tf.constant(num_keypoints_candidates_np,
dtype=tf.int32) dtype=tf.int32)
refined_keypoints, refined_scores = cnma.refine_keypoints( (refined_keypoints, refined_scores, _) = cnma.refine_keypoints(
regressed_keypoints, keypoint_candidates, keypoint_scores, regressed_keypoints,
num_keypoint_candidates, bboxes=None, keypoint_candidates,
keypoint_scores,
num_keypoint_candidates,
bboxes=None,
unmatched_keypoint_score=unmatched_keypoint_score, unmatched_keypoint_score=unmatched_keypoint_score,
box_scale=1.2, candidate_search_scale=0.3, box_scale=1.2,
candidate_search_scale=0.3,
candidate_ranking_mode=candidate_ranking_mode) candidate_ranking_mode=candidate_ranking_mode)
return refined_keypoints, refined_scores return refined_keypoints, refined_scores
...@@ -1057,7 +1144,8 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -1057,7 +1144,8 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints) np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints)
np.testing.assert_allclose(expected_refined_scores, refined_scores) np.testing.assert_allclose(expected_refined_scores, refined_scores)
def test_refine_keypoints_with_bboxes(self): @parameterized.parameters({'predict_depth': True}, {'predict_depth': False})
def test_refine_keypoints_with_bboxes(self, predict_depth):
regressed_keypoints_np = np.array( regressed_keypoints_np = np.array(
[ [
# Example 0. # Example 0.
...@@ -1096,7 +1184,22 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -1096,7 +1184,22 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
[0.7, 0.4, 0.0], # Candidate 0. [0.7, 0.4, 0.0], # Candidate 0.
[0.6, 0.1, 0.0], # Candidate 1. [0.6, 0.1, 0.0], # Candidate 1.
] ]
], dtype=np.float32) ],
dtype=np.float32)
keypoint_depths_np = np.array(
[
# Example 0.
[
[-0.8, -0.9, -1.0], # Candidate 0.
[-0.6, -0.1, -0.9], # Candidate 1.
],
# Example 1.
[
[-0.7, -0.4, -0.0], # Candidate 0.
[-0.6, -0.1, -0.0], # Candidate 1.
]
],
dtype=np.float32)
num_keypoints_candidates_np = np.array( num_keypoints_candidates_np = np.array(
[ [
# Example 0. # Example 0.
...@@ -1125,17 +1228,28 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -1125,17 +1228,28 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
keypoint_candidates = tf.constant( keypoint_candidates = tf.constant(
keypoint_candidates_np, dtype=tf.float32) keypoint_candidates_np, dtype=tf.float32)
keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32) keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32)
if predict_depth:
keypoint_depths = tf.constant(keypoint_depths_np, dtype=tf.float32)
else:
keypoint_depths = None
num_keypoint_candidates = tf.constant(num_keypoints_candidates_np, num_keypoint_candidates = tf.constant(num_keypoints_candidates_np,
dtype=tf.int32) dtype=tf.int32)
bboxes = tf.constant(bboxes_np, dtype=tf.float32) bboxes = tf.constant(bboxes_np, dtype=tf.float32)
refined_keypoints, refined_scores = cnma.refine_keypoints( (refined_keypoints, refined_scores,
regressed_keypoints, keypoint_candidates, keypoint_scores, refined_depths) = cnma.refine_keypoints(
num_keypoint_candidates, bboxes=bboxes, regressed_keypoints,
unmatched_keypoint_score=unmatched_keypoint_score, keypoint_candidates,
box_scale=1.0, candidate_search_scale=0.3) keypoint_scores,
return refined_keypoints, refined_scores num_keypoint_candidates,
bboxes=bboxes,
refined_keypoints, refined_scores = self.execute(graph_fn, []) unmatched_keypoint_score=unmatched_keypoint_score,
box_scale=1.0,
candidate_search_scale=0.3,
keypoint_depth_candidates=keypoint_depths)
if predict_depth:
return refined_keypoints, refined_scores, refined_depths
else:
return refined_keypoints, refined_scores
expected_refined_keypoints = np.array( expected_refined_keypoints = np.array(
[ [
...@@ -1166,8 +1280,17 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -1166,8 +1280,17 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
], ],
], dtype=np.float32) ], dtype=np.float32)
np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints) if predict_depth:
np.testing.assert_allclose(expected_refined_scores, refined_scores) refined_keypoints, refined_scores, refined_depths = self.execute(
graph_fn, [])
expected_refined_depths = np.array([[[-0.8, 0.0, 0.0], [0.0, 0.0, -1.0]],
[[-0.7, -0.1, 0.0], [-0.7, -0.4,
0.0]]])
np.testing.assert_allclose(expected_refined_depths, refined_depths)
else:
refined_keypoints, refined_scores = self.execute(graph_fn, [])
np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints)
np.testing.assert_allclose(expected_refined_scores, refined_scores)
def test_pad_to_full_keypoint_dim(self): def test_pad_to_full_keypoint_dim(self):
batch_size = 4 batch_size = 4
...@@ -1296,7 +1419,11 @@ def get_fake_od_params(): ...@@ -1296,7 +1419,11 @@ def get_fake_od_params():
scale_loss_weight=0.1) scale_loss_weight=0.1)
def get_fake_kp_params(num_candidates_per_keypoint=100): def get_fake_kp_params(num_candidates_per_keypoint=100,
per_keypoint_offset=False,
predict_depth=False,
per_keypoint_depth=False,
peak_radius=0):
"""Returns the fake keypoint estimation parameter namedtuple.""" """Returns the fake keypoint estimation parameter namedtuple."""
return cnma.KeypointEstimationParams( return cnma.KeypointEstimationParams(
task_name=_TASK_NAME, task_name=_TASK_NAME,
...@@ -1306,7 +1433,11 @@ def get_fake_kp_params(num_candidates_per_keypoint=100): ...@@ -1306,7 +1433,11 @@ def get_fake_kp_params(num_candidates_per_keypoint=100):
classification_loss=losses.WeightedSigmoidClassificationLoss(), classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(), localization_loss=losses.L1LocalizationLoss(),
keypoint_candidate_score_threshold=0.1, keypoint_candidate_score_threshold=0.1,
num_candidates_per_keypoint=num_candidates_per_keypoint) num_candidates_per_keypoint=num_candidates_per_keypoint,
per_keypoint_offset=per_keypoint_offset,
predict_depth=predict_depth,
per_keypoint_depth=per_keypoint_depth,
offset_peak_radius=peak_radius)
def get_fake_mask_params(): def get_fake_mask_params():
...@@ -1353,7 +1484,11 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1353,7 +1484,11 @@ def build_center_net_meta_arch(build_resnet=False,
num_classes=_NUM_CLASSES, num_classes=_NUM_CLASSES,
max_box_predictions=5, max_box_predictions=5,
apply_non_max_suppression=False, apply_non_max_suppression=False,
detection_only=False): detection_only=False,
per_keypoint_offset=False,
predict_depth=False,
per_keypoint_depth=False,
peak_radius=0):
"""Builds the CenterNet meta architecture.""" """Builds the CenterNet meta architecture."""
if build_resnet: if build_resnet:
feature_extractor = ( feature_extractor = (
...@@ -1407,7 +1542,10 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1407,7 +1542,10 @@ def build_center_net_meta_arch(build_resnet=False,
object_center_params=get_fake_center_params(max_box_predictions), object_center_params=get_fake_center_params(max_box_predictions),
object_detection_params=get_fake_od_params(), object_detection_params=get_fake_od_params(),
keypoint_params_dict={ keypoint_params_dict={
_TASK_NAME: get_fake_kp_params(num_candidates_per_keypoint) _TASK_NAME:
get_fake_kp_params(num_candidates_per_keypoint,
per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius)
}, },
non_max_suppression_fn=non_max_suppression_fn) non_max_suppression_fn=non_max_suppression_fn)
else: else:
...@@ -1992,6 +2130,84 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1992,6 +2130,84 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertAllEqual([1, 1, num_keypoints], self.assertAllEqual([1, 1, num_keypoints],
detections['detection_keypoint_scores'].shape) detections['detection_keypoint_scores'].shape)
@parameterized.parameters(
{'per_keypoint_depth': False},
{'per_keypoint_depth': True},
)
def test_postprocess_single_class_depth(self, per_keypoint_depth):
"""Test the postprocess function."""
model = build_center_net_meta_arch(
num_classes=1,
per_keypoint_offset=per_keypoint_depth,
predict_depth=True,
per_keypoint_depth=per_keypoint_depth)
num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
class_center = np.zeros((1, 32, 32, 1), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
offset = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
class_probs = np.zeros(1)
class_probs[0] = _logit(0.75)
class_center[0, 16, 16] = class_probs
height_width[0, 16, 16] = [5, 10]
offset[0, 16, 16] = [.25, .5]
keypoint_regression[0, 16, 16] = [-1., -1., -1., 1., 1., -1., 1., 1.]
keypoint_heatmaps[0, 14, 14, 0] = _logit(0.9)
keypoint_heatmaps[0, 14, 18, 1] = _logit(0.9)
keypoint_heatmaps[0, 18, 14, 2] = _logit(0.9)
keypoint_heatmaps[0, 18, 18, 3] = _logit(0.05) # Note the low score.
if per_keypoint_depth:
keypoint_depth = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
keypoint_depth[0, 14, 14, 0] = -1.0
keypoint_depth[0, 14, 18, 1] = -1.1
keypoint_depth[0, 18, 14, 2] = -1.2
keypoint_depth[0, 18, 18, 3] = -1.3
else:
keypoint_depth = np.zeros((1, 32, 32, 1), dtype=np.float32)
keypoint_depth[0, 14, 14, 0] = -1.0
keypoint_depth[0, 14, 18, 0] = -1.1
keypoint_depth[0, 18, 14, 0] = -1.2
keypoint_depth[0, 18, 18, 0] = -1.3
class_center = tf.constant(class_center)
height_width = tf.constant(height_width)
offset = tf.constant(offset)
keypoint_heatmaps = tf.constant(keypoint_heatmaps, dtype=tf.float32)
keypoint_offsets = tf.constant(keypoint_offsets, dtype=tf.float32)
keypoint_regression = tf.constant(keypoint_regression, dtype=tf.float32)
keypoint_depth = tf.constant(keypoint_depth, dtype=tf.float32)
prediction_dict = {
cnma.OBJECT_CENTER: [class_center],
cnma.BOX_SCALE: [height_width],
cnma.BOX_OFFSET: [offset],
cnma.get_keypoint_name(_TASK_NAME,
cnma.KEYPOINT_HEATMAP): [keypoint_heatmaps],
cnma.get_keypoint_name(_TASK_NAME,
cnma.KEYPOINT_OFFSET): [keypoint_offsets],
cnma.get_keypoint_name(_TASK_NAME,
cnma.KEYPOINT_REGRESSION): [keypoint_regression],
cnma.get_keypoint_name(_TASK_NAME,
cnma.KEYPOINT_DEPTH): [keypoint_depth]
}
def graph_fn():
detections = model.postprocess(prediction_dict,
tf.constant([[128, 128, 3]]))
return detections
detections = self.execute_cpu(graph_fn, [])
self.assertAllClose(detections['detection_keypoint_depths'][0, 0],
np.array([-1.0, -1.1, -1.2, 0.0]))
self.assertAllClose(detections['detection_keypoint_scores'][0, 0],
np.array([0.9, 0.9, 0.9, 0.1]))
def test_get_instance_indices(self): def test_get_instance_indices(self):
classes = tf.constant([[0, 1, 2, 0], [2, 1, 2, 2]], dtype=tf.int32) classes = tf.constant([[0, 1, 2, 0], [2, 1, 2, 2]], dtype=tf.int32)
num_detections = tf.constant([1, 3], dtype=tf.int32) num_detections = tf.constant([1, 3], dtype=tf.int32)
...@@ -2003,7 +2219,10 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2003,7 +2219,10 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertAllEqual(valid_indices.numpy(), [0, 2]) self.assertAllEqual(valid_indices.numpy(), [0, 2])
def get_fake_prediction_dict(input_height, input_width, stride): def get_fake_prediction_dict(input_height,
input_width,
stride,
per_keypoint_depth=False):
"""Prepares the fake prediction dictionary.""" """Prepares the fake prediction dictionary."""
output_height = input_height // stride output_height = input_height // stride
output_width = input_width // stride output_width = input_width // stride
...@@ -2038,6 +2257,11 @@ def get_fake_prediction_dict(input_height, input_width, stride): ...@@ -2038,6 +2257,11 @@ def get_fake_prediction_dict(input_height, input_width, stride):
dtype=np.float32) dtype=np.float32)
keypoint_offset[0, 2, 4] = 0.2, 0.4 keypoint_offset[0, 2, 4] = 0.2, 0.4
keypoint_depth = np.zeros((2, output_height, output_width,
_NUM_KEYPOINTS if per_keypoint_depth else 1),
dtype=np.float32)
keypoint_depth[0, 2, 4] = 3.0
keypoint_regression = np.zeros( keypoint_regression = np.zeros(
(2, output_height, output_width, 2 * _NUM_KEYPOINTS), dtype=np.float32) (2, output_height, output_width, 2 * _NUM_KEYPOINTS), dtype=np.float32)
keypoint_regression[0, 2, 4] = 0.0, 0.0, 0.2, 0.4, 0.0, 0.0, 0.2, 0.4 keypoint_regression[0, 2, 4] = 0.0, 0.0, 0.2, 0.4, 0.0, 0.0, 0.2, 0.4
...@@ -2073,14 +2297,10 @@ def get_fake_prediction_dict(input_height, input_width, stride): ...@@ -2073,14 +2297,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
tf.constant(object_center), tf.constant(object_center),
tf.constant(object_center) tf.constant(object_center)
], ],
cnma.BOX_SCALE: [ cnma.BOX_SCALE: [tf.constant(object_scale),
tf.constant(object_scale), tf.constant(object_scale)],
tf.constant(object_scale) cnma.BOX_OFFSET: [tf.constant(object_offset),
], tf.constant(object_offset)],
cnma.BOX_OFFSET: [
tf.constant(object_offset),
tf.constant(object_offset)
],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_HEATMAP): [ cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_HEATMAP): [
tf.constant(keypoint_heatmap), tf.constant(keypoint_heatmap),
tf.constant(keypoint_heatmap) tf.constant(keypoint_heatmap)
...@@ -2093,6 +2313,10 @@ def get_fake_prediction_dict(input_height, input_width, stride): ...@@ -2093,6 +2313,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
tf.constant(keypoint_regression), tf.constant(keypoint_regression),
tf.constant(keypoint_regression) tf.constant(keypoint_regression)
], ],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_DEPTH): [
tf.constant(keypoint_depth),
tf.constant(keypoint_depth)
],
cnma.SEGMENTATION_HEATMAP: [ cnma.SEGMENTATION_HEATMAP: [
tf.constant(mask_heatmap), tf.constant(mask_heatmap),
tf.constant(mask_heatmap) tf.constant(mask_heatmap)
...@@ -2117,7 +2341,10 @@ def get_fake_prediction_dict(input_height, input_width, stride): ...@@ -2117,7 +2341,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
return prediction_dict return prediction_dict
def get_fake_groundtruth_dict(input_height, input_width, stride): def get_fake_groundtruth_dict(input_height,
input_width,
stride,
has_depth=False):
"""Prepares the fake groundtruth dictionary.""" """Prepares the fake groundtruth dictionary."""
# A small box with center at (0.55, 0.55). # A small box with center at (0.55, 0.55).
boxes = [ boxes = [
...@@ -2146,6 +2373,26 @@ def get_fake_groundtruth_dict(input_height, input_width, stride): ...@@ -2146,6 +2373,26 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
axis=2), axis=2),
multiples=[1, 1, 2]), multiples=[1, 1, 2]),
] ]
if has_depth:
keypoint_depths = [
tf.constant([[float('nan'), 3.0,
float('nan'), 3.0, 0.55, 0.0]]),
tf.constant([[float('nan'), 0.55,
float('nan'), 0.55, 0.55, 0.0]])
]
keypoint_depth_weights = [
tf.constant([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0]]),
tf.constant([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0]])
]
else:
keypoint_depths = [
tf.constant([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]),
tf.constant([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
]
keypoint_depth_weights = [
tf.constant([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]),
tf.constant([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
]
labeled_classes = [ labeled_classes = [
tf.one_hot([1], depth=_NUM_CLASSES) + tf.one_hot([2], depth=_NUM_CLASSES), tf.one_hot([1], depth=_NUM_CLASSES) + tf.one_hot([2], depth=_NUM_CLASSES),
tf.one_hot([0], depth=_NUM_CLASSES) + tf.one_hot([1], depth=_NUM_CLASSES), tf.one_hot([0], depth=_NUM_CLASSES) + tf.one_hot([1], depth=_NUM_CLASSES),
...@@ -2187,11 +2434,12 @@ def get_fake_groundtruth_dict(input_height, input_width, stride): ...@@ -2187,11 +2434,12 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
fields.BoxListFields.weights: weights, fields.BoxListFields.weights: weights,
fields.BoxListFields.classes: classes, fields.BoxListFields.classes: classes,
fields.BoxListFields.keypoints: keypoints, fields.BoxListFields.keypoints: keypoints,
fields.BoxListFields.keypoint_depths: keypoint_depths,
fields.BoxListFields.keypoint_depth_weights: keypoint_depth_weights,
fields.BoxListFields.masks: masks, fields.BoxListFields.masks: masks,
fields.BoxListFields.densepose_num_points: densepose_num_points, fields.BoxListFields.densepose_num_points: densepose_num_points,
fields.BoxListFields.densepose_part_ids: densepose_part_ids, fields.BoxListFields.densepose_part_ids: densepose_part_ids,
fields.BoxListFields.densepose_surface_coords: fields.BoxListFields.densepose_surface_coords: densepose_surface_coords,
densepose_surface_coords,
fields.BoxListFields.track_ids: track_ids, fields.BoxListFields.track_ids: track_ids,
fields.BoxListFields.temporal_offsets: temporal_offsets, fields.BoxListFields.temporal_offsets: temporal_offsets,
fields.BoxListFields.track_match_flags: track_match_flags, fields.BoxListFields.track_match_flags: track_match_flags,
...@@ -2201,7 +2449,7 @@ def get_fake_groundtruth_dict(input_height, input_width, stride): ...@@ -2201,7 +2449,7 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class CenterNetMetaComputeLossTest(test_case.TestCase): class CenterNetMetaComputeLossTest(test_case.TestCase, parameterized.TestCase):
"""Test for CenterNet loss compuation related functions.""" """Test for CenterNet loss compuation related functions."""
def setUp(self): def setUp(self):
...@@ -2328,6 +2576,45 @@ class CenterNetMetaComputeLossTest(test_case.TestCase): ...@@ -2328,6 +2576,45 @@ class CenterNetMetaComputeLossTest(test_case.TestCase):
# The prediction and groundtruth are curated to produce very low loss. # The prediction and groundtruth are curated to produce very low loss.
self.assertGreater(0.01, loss) self.assertGreater(0.01, loss)
  @parameterized.parameters(
      {'per_keypoint_depth': False},
      {'per_keypoint_depth': True},
  )
  def test_compute_kp_depth_loss(self, per_keypoint_depth):
    """Tests the keypoint depth loss in both shared and per-keypoint modes.

    Args:
      per_keypoint_depth: Whether the model predicts a separate depth channel
        per keypoint (True) or a single shared depth channel (False). Also
        toggles per-keypoint offsets and a nonzero peak radius so the
        per-keypoint path exercises the disk-based loss computation.
    """
    # Fake predictions are shaped to match the chosen depth mode.
    prediction_dict = get_fake_prediction_dict(
        self.input_height,
        self.input_width,
        self.stride,
        per_keypoint_depth=per_keypoint_depth)
    # NOTE: per_keypoint_offset is tied to per_keypoint_depth here, and the
    # peak radius is 1 only in per-keypoint mode (exercising the disk loss).
    model = build_center_net_meta_arch(
        num_classes=1,
        per_keypoint_offset=per_keypoint_depth,
        predict_depth=True,
        per_keypoint_depth=per_keypoint_depth,
        peak_radius=1 if per_keypoint_depth else 0)
    # Inject groundtruth (with depth targets) directly, bypassing
    # provide_groundtruth, since only the depth loss is under test.
    model._groundtruth_lists = get_fake_groundtruth_dict(
        self.input_height, self.input_width, self.stride, has_depth=True)

    def graph_fn():
      # Calls the private depth-loss helper directly with the fake
      # depth-prediction head output for this task.
      loss = model._compute_kp_depth_loss(
          input_height=self.input_height,
          input_width=self.input_width,
          task_name=_TASK_NAME,
          depth_predictions=prediction_dict[cnma.get_keypoint_name(
              _TASK_NAME, cnma.KEYPOINT_DEPTH)],
          localization_loss_fn=self.localization_loss_fn)
      return loss

    loss = self.execute(graph_fn, [])

    if per_keypoint_depth:
      # The loss is computed on a disk with radius 1 but only the center pixel
      # has the accurate prediction. The final loss is (4 * |3-0|) / 5 = 2.4
      self.assertAlmostEqual(2.4, loss, delta=1e-4)
    else:
      # The prediction and groundtruth are curated to produce very low loss.
      self.assertGreater(0.01, loss)
def test_compute_track_embedding_loss(self): def test_compute_track_embedding_loss(self):
default_fc = self.model.track_reid_classification_net default_fc = self.model.track_reid_classification_net
# Initialize the kernel to extreme values so that the classification score # Initialize the kernel to extreme values so that the classification score
......
...@@ -165,6 +165,21 @@ message CenterNet { ...@@ -165,6 +165,21 @@ message CenterNet {
// out_height, out_width, 2 * num_keypoints] (recommended when the // out_height, out_width, 2 * num_keypoints] (recommended when the
// offset_peak_radius is not zero). // offset_peak_radius is not zero).
optional bool per_keypoint_offset = 18 [default = false]; optional bool per_keypoint_offset = 18 [default = false];
// Indicates whether to predict the depth of each keypoints. Note that this
// is only supported in the single class keypoint task.
optional bool predict_depth = 19 [default = false];
// Indicates whether to predict depths for each keypoint channel
// separately. If set False, the output depth target has the shape
// [batch_size, out_height, out_width, 1]. If set True, the output depth
// target has the shape [batch_size, out_height, out_width,
// num_keypoints]. Recommend to set this value and "per_keypoint_offset" to
// both be True at the same time.
optional bool per_keypoint_depth = 20 [default = false];
// The weight of the keypoint depth loss.
optional float keypoint_depth_loss_weight = 21 [default = 1.0];
} }
repeated KeypointEstimation keypoint_estimation_task = 7; repeated KeypointEstimation keypoint_estimation_task = 7;
...@@ -278,7 +293,6 @@ message CenterNet { ...@@ -278,7 +293,6 @@ message CenterNet {
// from CenterNet. Use this optional parameter to apply traditional non max // from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding. // suppression and score thresholding.
optional PostProcessing post_processing = 24; optional PostProcessing post_processing = 24;
} }
message CenterNetFeatureExtractor { message CenterNetFeatureExtractor {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment