Commit eab78118 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Handle label_confidences in random_square_crop_by_scale.

PiperOrigin-RevId: 324652818
parent 1ea5e1f6
...@@ -3971,9 +3971,10 @@ def _get_crop_border(border, size): ...@@ -3971,9 +3971,10 @@ def _get_crop_border(border, size):
def random_square_crop_by_scale(image, boxes, labels, label_weights, def random_square_crop_by_scale(image, boxes, labels, label_weights,
masks=None, keypoints=None, max_border=128, label_confidences=None, masks=None,
scale_min=0.6, scale_max=1.3, num_scales=8, keypoints=None, max_border=128, scale_min=0.6,
seed=None, preprocess_vars_cache=None): scale_max=1.3, num_scales=8, seed=None,
preprocess_vars_cache=None):
"""Randomly crop a square in proportion to scale and image size. """Randomly crop a square in proportion to scale and image size.
Extract a square sized crop from an image whose side length is sampled by Extract a square sized crop from an image whose side length is sampled by
...@@ -3993,6 +3994,8 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights, ...@@ -3993,6 +3994,8 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights,
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_weights: float32 tensor of shape [num_instances] representing the label_weights: float32 tensor of shape [num_instances] representing the
weight for each box. weight for each box.
label_confidences: (optional) float32 tensor of shape [num_instances]
representing the confidence for each box.
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
...@@ -4021,6 +4024,8 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights, ...@@ -4021,6 +4024,8 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights,
Boxes are in normalized form. Boxes are in normalized form.
labels: new labels. labels: new labels.
label_weights: rank 1 float32 tensor with shape [num_instances]. label_weights: rank 1 float32 tensor with shape [num_instances].
label_confidences: (optional) float32 tensor of shape [num_instances]
representing the confidence for each box.
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
...@@ -4110,6 +4115,9 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights, ...@@ -4110,6 +4115,9 @@ def random_square_crop_by_scale(image, boxes, labels, label_weights,
tf.gather(labels, indices), tf.gather(labels, indices),
tf.gather(label_weights, indices)] tf.gather(label_weights, indices)]
if label_confidences is not None:
return_values.append(tf.gather(label_confidences, indices))
if masks is not None: if masks is not None:
new_masks = tf.expand_dims(masks, -1) new_masks = tf.expand_dims(masks, -1)
new_masks = new_masks[:, ymin:ymax, xmin:xmax] new_masks = new_masks[:, ymin:ymax, xmin:xmax]
...@@ -4483,8 +4491,8 @@ def get_default_func_arg_map(include_label_weights=True, ...@@ -4483,8 +4491,8 @@ def get_default_func_arg_map(include_label_weights=True,
(fields.InputDataFields.image, (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
groundtruth_label_weights, groundtruth_instance_masks, groundtruth_label_weights, groundtruth_label_confidences,
groundtruth_keypoints), groundtruth_instance_masks, groundtruth_keypoints),
random_scale_crop_and_pad_to_square: random_scale_crop_and_pad_to_square:
(fields.InputDataFields.image, (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
...@@ -4541,7 +4549,6 @@ def preprocess(tensor_dict, ...@@ -4541,7 +4549,6 @@ def preprocess(tensor_dict,
""" """
if func_arg_map is None: if func_arg_map is None:
func_arg_map = get_default_func_arg_map() func_arg_map = get_default_func_arg_map()
# changes the images to image (rank 4 to rank 3) since the functions # changes the images to image (rank 4 to rank 3) since the functions
# receive rank 3 tensor for image # receive rank 3 tensor for image
if fields.InputDataFields.image in tensor_dict: if fields.InputDataFields.image in tensor_dict:
......
...@@ -3814,21 +3814,23 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase): ...@@ -3814,21 +3814,23 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
boxes = tf.constant([[0.25, .25, .75, .75]]) boxes = tf.constant([[0.25, .25, .75, .75]])
labels = tf.constant([[1]]) labels = tf.constant([[1]])
label_confidences = tf.constant([0.75])
label_weights = tf.constant([[1.]]) label_weights = tf.constant([[1.]])
(new_image, new_boxes, _, _, new_masks, (new_image, new_boxes, _, _, new_confidences, new_masks,
new_keypoints) = preprocessor.random_square_crop_by_scale( new_keypoints) = preprocessor.random_square_crop_by_scale(
image, image,
boxes, boxes,
labels, labels,
label_weights, label_weights,
label_confidences,
masks=masks, masks=masks,
keypoints=keypoints, keypoints=keypoints,
max_border=256, max_border=256,
scale_min=scale, scale_min=scale,
scale_max=scale) scale_max=scale)
return new_image, new_boxes, new_masks, new_keypoints return new_image, new_boxes, new_confidences, new_masks, new_keypoints
image, boxes, masks, keypoints = self.execute_cpu(graph_fn, []) image, boxes, confidences, masks, keypoints = self.execute_cpu(graph_fn, [])
ymin, xmin, ymax, xmax = boxes[0] ymin, xmin, ymax, xmax = boxes[0]
self.assertAlmostEqual(ymax - ymin, 0.5 / scale) self.assertAlmostEqual(ymax - ymin, 0.5 / scale)
self.assertAlmostEqual(xmax - xmin, 0.5 / scale) self.assertAlmostEqual(xmax - xmin, 0.5 / scale)
...@@ -3842,6 +3844,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase): ...@@ -3842,6 +3844,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self.assertAlmostEqual(scale * 256.0, size) self.assertAlmostEqual(scale * 256.0, size)
self.assertAllClose(image[:, :, 0], masks[0, :, :]) self.assertAllClose(image[:, :, 0], masks[0, :, :])
self.assertAllClose(confidences, [0.75])
@parameterized.named_parameters(('scale_0_1', 0.1), ('scale_1_0', 1.0), @parameterized.named_parameters(('scale_0_1', 0.1), ('scale_1_0', 1.0),
('scale_2_0', 2.0)) ('scale_2_0', 2.0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment