Commit eab78118 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Handle label_confidences in random_square_crop_by_scale.

PiperOrigin-RevId: 324652818
parent 1ea5e1f6
......@@ -3971,9 +3971,10 @@ def _get_crop_border(border, size):
def random_square_crop_by_scale(image, boxes, labels, label_weights,
masks=None, keypoints=None, max_border=128,
scale_min=0.6, scale_max=1.3, num_scales=8,
seed=None, preprocess_vars_cache=None):
label_confidences=None, masks=None,
keypoints=None, max_border=128, scale_min=0.6,
scale_max=1.3, num_scales=8, seed=None,
preprocess_vars_cache=None):
"""Randomly crop a square in proportion to scale and image size.
Extract a square sized crop from an image whose side length is sampled by
......@@ -3993,6 +3994,8 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights,
labels: rank 1 int32 tensor containing the object classes.
label_weights: float32 tensor of shape [num_instances] representing the
weight for each box.
label_confidences: (optional) float32 tensor of shape [num_instances]
representing the confidence for each box.
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`.
......@@ -4021,6 +4024,8 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights,
Boxes are in normalized form.
labels: new labels.
label_weights: rank 1 float32 tensor with shape [num_instances].
label_confidences: (optional) float32 tensor of shape [num_instances]
representing the confidence for each box.
masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks.
......@@ -4110,6 +4115,9 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights,
tf.gather(labels, indices),
tf.gather(label_weights, indices)]
if label_confidences is not None:
return_values.append(tf.gather(label_confidences, indices))
if masks is not None:
new_masks = tf.expand_dims(masks, -1)
new_masks = new_masks[:, ymin:ymax, xmin:xmax]
......@@ -4483,8 +4491,8 @@ def get_default_func_arg_map(include_label_weights=True,
(fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights, groundtruth_instance_masks,
groundtruth_keypoints),
groundtruth_label_weights, groundtruth_label_confidences,
groundtruth_instance_masks, groundtruth_keypoints),
random_scale_crop_and_pad_to_square:
(fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes,
......@@ -4541,7 +4549,6 @@ def preprocess(tensor_dict,
"""
if func_arg_map is None:
func_arg_map = get_default_func_arg_map()
# changes the images to image (rank 4 to rank 3) since the functions
# receive rank 3 tensor for image
if fields.InputDataFields.image in tensor_dict:
......
......@@ -3814,21 +3814,23 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
boxes = tf.constant([[0.25, .25, .75, .75]])
labels = tf.constant([[1]])
label_confidences = tf.constant([0.75])
label_weights = tf.constant([[1.]])
(new_image, new_boxes, _, _, new_masks,
(new_image, new_boxes, _, _, new_confidences, new_masks,
new_keypoints) = preprocessor.random_square_crop_by_scale(
image,
boxes,
labels,
label_weights,
label_confidences,
masks=masks,
keypoints=keypoints,
max_border=256,
scale_min=scale,
scale_max=scale)
return new_image, new_boxes, new_masks, new_keypoints
image, boxes, masks, keypoints = self.execute_cpu(graph_fn, [])
return new_image, new_boxes, new_confidences, new_masks, new_keypoints
image, boxes, confidences, masks, keypoints = self.execute_cpu(graph_fn, [])
ymin, xmin, ymax, xmax = boxes[0]
self.assertAlmostEqual(ymax - ymin, 0.5 / scale)
self.assertAlmostEqual(xmax - xmin, 0.5 / scale)
......@@ -3842,6 +3844,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self.assertAlmostEqual(scale * 256.0, size)
self.assertAllClose(image[:, :, 0], masks[0, :, :])
self.assertAllClose(confidences, [0.75])
@parameterized.named_parameters(('scale_0_1', 0.1), ('scale_1_0', 1.0),
('scale_2_0', 2.0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment