Commit 58644b96 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Make implementation reframe_box_masks_to_image_masks to be more clear.

PiperOrigin-RevId: 398596089
parent 0f0b060c
...@@ -947,7 +947,7 @@ class ObjectDetectionEvaluatorTest(tf.test.TestCase, parameterized.TestCase): ...@@ -947,7 +947,7 @@ class ObjectDetectionEvaluatorTest(tf.test.TestCase, parameterized.TestCase):
axis=0) axis=0)
detection_classes = tf.tile(tf.constant([[0]]), multiples=[batch_size, 1]) detection_classes = tf.tile(tf.constant([[0]]), multiples=[batch_size, 1])
detection_masks = tf.tile( detection_masks = tf.tile(
tf.ones(shape=[1, 2, 20, 20], dtype=tf.float32), tf.ones(shape=[1, 1, 20, 20], dtype=tf.float32),
multiples=[batch_size, 1, 1, 1]) multiples=[batch_size, 1, 1, 1])
groundtruth_boxes = tf.constant([[0., 0., 1., 1.]]) groundtruth_boxes = tf.constant([[0., 0., 1., 1.]])
groundtruth_classes = tf.constant([1]) groundtruth_classes = tf.constant([1])
...@@ -972,6 +972,7 @@ class ObjectDetectionEvaluatorTest(tf.test.TestCase, parameterized.TestCase): ...@@ -972,6 +972,7 @@ class ObjectDetectionEvaluatorTest(tf.test.TestCase, parameterized.TestCase):
detection_fields.detection_masks: detection_masks, detection_fields.detection_masks: detection_masks,
detection_fields.num_detections: num_detections detection_fields.num_detections: num_detections
} }
groundtruth = { groundtruth = {
input_data_fields.groundtruth_boxes: input_data_fields.groundtruth_boxes:
groundtruth_boxes, groundtruth_boxes,
......
...@@ -798,6 +798,31 @@ def position_sensitive_crop_regions(image, ...@@ -798,6 +798,31 @@ def position_sensitive_crop_regions(image,
return position_sensitive_features return position_sensitive_features
def reframe_image_corners_relative_to_boxes(boxes):
"""Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.
The local coordinate frame of each box is assumed to be relative to
its own for corners.
Args:
boxes: A float tensor of [num_boxes, 4] of (ymin, xmin, ymax, xmax)
coordinates in relative coordinate space of each bounding box.
Returns:
reframed_boxes: Reframes boxes with same shape as input.
"""
ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1)
height = tf.maximum(ymax - ymin, 1e-4)
width = tf.maximum(xmax - xmin, 1e-4)
ymin_out = (0 - ymin) / height
xmin_out = (0 - xmin) / width
ymax_out = (1 - ymin) / height
xmax_out = (1 - xmin) / width
return tf.stack([ymin_out, xmin_out, ymax_out, xmax_out], axis=1)
def reframe_box_masks_to_image_masks(box_masks, boxes, image_height, def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
image_width, resize_method='bilinear'): image_width, resize_method='bilinear'):
"""Transforms the box masks back to full image masks. """Transforms the box masks back to full image masks.
...@@ -826,27 +851,16 @@ def reframe_box_masks_to_image_masks(box_masks, boxes, image_height, ...@@ -826,27 +851,16 @@ def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
# TODO(rathodv): Make this a public function. # TODO(rathodv): Make this a public function.
def reframe_box_masks_to_image_masks_default(): def reframe_box_masks_to_image_masks_default():
"""The default function when there are more than 0 box masks.""" """The default function when there are more than 0 box masks."""
def transform_boxes_relative_to_boxes(boxes, reference_boxes):
boxes = tf.reshape(boxes, [-1, 2, 2])
min_corner = tf.expand_dims(reference_boxes[:, 0:2], 1)
max_corner = tf.expand_dims(reference_boxes[:, 2:4], 1)
denom = max_corner - min_corner
# Prevent a divide by zero.
denom = tf.math.maximum(denom, 1e-4)
transformed_boxes = (boxes - min_corner) / denom
return tf.reshape(transformed_boxes, [-1, 4])
num_boxes = tf.shape(box_masks)[0]
box_masks_expanded = tf.expand_dims(box_masks, axis=3) box_masks_expanded = tf.expand_dims(box_masks, axis=3)
num_boxes = tf.shape(box_masks_expanded)[0]
unit_boxes = tf.concat(
[tf.zeros([num_boxes, 2]), tf.ones([num_boxes, 2])], axis=1)
reverse_boxes = transform_boxes_relative_to_boxes(unit_boxes, boxes)
# TODO(vighneshb) Use matmul_crop_and_resize so that the output shape # TODO(vighneshb) Use matmul_crop_and_resize so that the output shape
# is static. This will help us run and test on TPUs. # is static. This will help us run and test on TPUs.
resized_crops = tf.image.crop_and_resize( resized_crops = tf.image.crop_and_resize(
image=box_masks_expanded, image=box_masks_expanded,
boxes=reverse_boxes, boxes=reframe_image_corners_relative_to_boxes(boxes),
box_ind=tf.range(num_boxes), box_ind=tf.range(num_boxes),
crop_size=[image_height, image_width], crop_size=[image_height, image_width],
method=resize_method, method=resize_method,
......
...@@ -1195,6 +1195,14 @@ class OpsTestBatchPositionSensitiveCropRegions(test_case.TestCase): ...@@ -1195,6 +1195,14 @@ class OpsTestBatchPositionSensitiveCropRegions(test_case.TestCase):
class ReframeBoxMasksToImageMasksTest(test_case.TestCase, class ReframeBoxMasksToImageMasksTest(test_case.TestCase,
parameterized.TestCase): parameterized.TestCase):
def test_reframe_image_corners_relative_to_boxes(self):
def graph_fn():
return ops.reframe_image_corners_relative_to_boxes(
tf.constant([[0.1, 0.2, 0.3, 0.4]]))
np_boxes = self.execute_cpu(graph_fn, [])
self.assertAllClose(np_boxes, [[-0.5, -1, 4.5, 4.]])
@parameterized.parameters( @parameterized.parameters(
{'mask_dtype': tf.float32, 'mask_dtype_np': np.float32, {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
'resize_method': 'bilinear'}, 'resize_method': 'bilinear'},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment