Commit 42b49ff1 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Make resize_detection_masks function public.

PiperOrigin-RevId: 328348470
parent b6b2eb2d
...@@ -559,6 +559,30 @@ def _resize_detection_masks(args): ...@@ -559,6 +559,30 @@ def _resize_detection_masks(args):
return tf.cast(detection_masks_reframed, tf.uint8) return tf.cast(detection_masks_reframed, tf.uint8)
def resize_detection_masks(detection_boxes, detection_masks,
original_image_spatial_shapes):
"""Resizes per-box detection masks to be relative to the entire image.
Note that this function only works when the spatial size of all images in
the batch is the same. If not, this function should be used with batch_size=1.
Args:
detection_boxes: A [batch_size, num_instances, 4] float tensor containing
bounding boxes.
detection_masks: A [batch_suze, num_instances, height, width] float tensor
containing binary instance masks per box.
original_image_spatial_shapes: a [batch_size, 3] shaped int tensor
holding the spatial dimensions of each image in the batch.
Returns:
masks: Masks resized to the spatial extents given by
(original_image_spatial_shapes[0, 0], original_image_spatial_shapes[0, 1])
"""
return shape_utils.static_or_dynamic_map_fn(
_resize_detection_masks,
elems=[detection_boxes, detection_masks, original_image_spatial_shapes],
dtype=tf.uint8)
def _resize_groundtruth_masks(args): def _resize_groundtruth_masks(args):
"""Resizes groundgtruth masks to the original image size.""" """Resizes groundgtruth masks to the original image size."""
mask, true_image_shape, original_image_shape = args mask, true_image_shape, original_image_shape = args
...@@ -869,12 +893,9 @@ def result_dict_for_batched_example(images, ...@@ -869,12 +893,9 @@ def result_dict_for_batched_example(images,
if detection_fields.detection_masks in detections: if detection_fields.detection_masks in detections:
detection_masks = detections[detection_fields.detection_masks] detection_masks = detections[detection_fields.detection_masks]
output_dict[detection_fields.detection_masks] = ( output_dict[detection_fields.detection_masks] = resize_detection_masks(
shape_utils.static_or_dynamic_map_fn( detection_boxes, detection_masks, original_image_spatial_shapes)
_resize_detection_masks,
elems=[detection_boxes, detection_masks,
original_image_spatial_shapes],
dtype=tf.uint8))
if detection_fields.detection_surface_coords in detections: if detection_fields.detection_surface_coords in detections:
detection_surface_coords = detections[ detection_surface_coords = detections[
detection_fields.detection_surface_coords] detection_fields.detection_surface_coords]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment