Commit 8a16208b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Call set_keypoint_visibilities on the non-expanded versions of the detection...

Call set_keypoint_visibilities on the non-expanded versions of the detection and groundtruth keypoints. set_keypoint_visibilities expects a rank-3 tensor, and was being provided a rank-4 tensor. This had the unintended effect of creating a keypoint visibilities tensor of the wrong shape, resulting in only 2 keypoints being visualized.

PiperOrigin-RevId: 368458273
parent 76476cd9
......@@ -56,6 +56,7 @@ def clip_to_window(keypoints, window, scope=None):
Returns:
new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
"""
keypoints.get_shape().assert_has_rank(3)
with tf.name_scope(scope, 'ClipToWindow'):
y, x = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
......@@ -81,6 +82,7 @@ def prune_outside_window(keypoints, window, scope=None):
Returns:
new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
"""
keypoints.get_shape().assert_has_rank(3)
with tf.name_scope(scope, 'PruneOutsideWindow'):
y, x = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
......@@ -242,6 +244,7 @@ def flip_horizontal(keypoints, flip_point, flip_permutation=None, scope=None):
Returns:
new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
"""
keypoints.get_shape().assert_has_rank(3)
with tf.name_scope(scope, 'FlipHorizontal'):
keypoints = tf.transpose(keypoints, [1, 0, 2])
if flip_permutation:
......@@ -276,6 +279,7 @@ def flip_vertical(keypoints, flip_point, flip_permutation=None, scope=None):
Returns:
new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
"""
keypoints.get_shape().assert_has_rank(3)
with tf.name_scope(scope, 'FlipVertical'):
keypoints = tf.transpose(keypoints, [1, 0, 2])
if flip_permutation:
......@@ -301,6 +305,7 @@ def rot90(keypoints, rotation_permutation=None, scope=None):
Returns:
new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
"""
keypoints.get_shape().assert_has_rank(3)
with tf.name_scope(scope, 'Rot90'):
keypoints = tf.transpose(keypoints, [1, 0, 2])
if rotation_permutation:
......@@ -336,6 +341,7 @@ def keypoint_weights_from_visibilities(keypoint_visibilities,
keypoints deemed visible will have the provided per-keypoint weight, and
all others will be set to zero.
"""
keypoint_visibilities.get_shape().assert_has_rank(2)
if per_keypoint_weights is None:
num_keypoints = keypoint_visibilities.shape.as_list()[1]
per_keypoint_weight_mult = tf.ones((1, num_keypoints,), dtype=tf.float32)
......@@ -365,6 +371,7 @@ def set_keypoint_visibilities(keypoints, initial_keypoint_visibilities=None):
keypoint_visibilities: a bool tensor of shape [num_instances, num_keypoints]
indicating whether a keypoint is visible or not.
"""
keypoints.get_shape().assert_has_rank(3)
if initial_keypoint_visibilities is not None:
keypoint_visibilities = tf.cast(initial_keypoint_visibilities, tf.bool)
else:
......
......@@ -684,8 +684,10 @@ def draw_side_by_side_evaluation_image(eval_dict,
keypoint_scores = tf.expand_dims(
eval_dict[detection_fields.detection_keypoint_scores][indx], axis=0)
else:
keypoint_scores = tf.cast(keypoint_ops.set_keypoint_visibilities(
keypoints), dtype=tf.float32)
keypoint_scores = tf.expand_dims(tf.cast(
keypoint_ops.set_keypoint_visibilities(
eval_dict[detection_fields.detection_keypoints][indx]),
dtype=tf.float32), axis=0)
groundtruth_instance_masks = None
if input_data_fields.groundtruth_instance_masks in eval_dict:
......@@ -703,9 +705,10 @@ def draw_side_by_side_evaluation_image(eval_dict,
groundtruth_keypoint_scores = tf.expand_dims(
tf.cast(eval_dict[gt_kpt_vis_fld][indx], dtype=tf.float32), axis=0)
else:
groundtruth_keypoint_scores = tf.cast(
groundtruth_keypoint_scores = tf.expand_dims(tf.cast(
keypoint_ops.set_keypoint_visibilities(
groundtruth_keypoints), dtype=tf.float32)
eval_dict[input_data_fields.groundtruth_keypoints][indx]),
dtype=tf.float32), axis=0)
images_with_detections = draw_bounding_boxes_on_image_tensors(
tf.expand_dims(
eval_dict[input_data_fields.original_image][indx], axis=0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment