Unverified Commit 43081990 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

use grid sampling to paste masks

parent 4f536f45
......@@ -18,6 +18,7 @@ from typing import List
import tensorflow as tf
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import paste_masks
class PanopticSegmentationGenerator(tf.keras.layers.Layer):
"""Panoptic segmentation generator layer."""
......@@ -79,34 +80,10 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
}
super(PanopticSegmentationGenerator, self).__init__(**kwargs)
def _paste_mask(self, box, mask):
pasted_mask = tf.ones(
self._output_size + [1], dtype=mask.dtype) * self._void_class_label
ymin = tf.clip_by_value(box[0], 0, self._output_size[0])
xmin = tf.clip_by_value(box[1], 0, self._output_size[1])
ymax = tf.clip_by_value(box[2] + 1, 0, self._output_size[0])
xmax = tf.clip_by_value(box[3] + 1, 0, self._output_size[1])
box_height = ymax - ymin
box_width = xmax - xmin
if not (box_height == 0 or box_width == 0):
# resize mask to match the shape of the instance bounding box
resized_mask = tf.image.resize(
mask,
size=(box_height, box_width),
method='bilinear')
resized_mask = tf.cast(resized_mask, dtype=mask.dtype)
# paste resized mask on a blank mask that matches image shape
pasted_mask = tf.raw_ops.TensorStridedSliceUpdate(
input=pasted_mask,
begin=[ymin, xmin],
end=[ymax, xmax],
strides=[1, 1],
value=resized_mask)
return pasted_mask
def build(self, input_shape):
grid_sampler = paste_masks.BilinearGridSampler(align_corners=False)
self._paste_masks_fn = paste_masks.PasteMasks(
output_size=self._output_size, grid_sampler=grid_sampler)
def _generate_panoptic_masks(self, boxes, scores, classes, detections_masks,
segmentation_mask):
......@@ -138,6 +115,9 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
- category_mask: A `tf.Tensor` for category masks.
- instance_mask: A `tf.Tensor for instance masks.
"""
# Paste instance masks
pasted_masks = self._paste_masks_fn((detections_masks, boxes))
# Offset stuff class predictions
segmentation_mask = tf.where(
tf.logical_or(
......@@ -155,6 +135,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
instance_mask = tf.ones(
mask_shape, dtype=tf.float32) * self._void_instance_id
# filter instances with low confidence
sorted_scores = tf.sort(scores, direction='DESCENDING')
......@@ -174,9 +155,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
# the overlaps are resolved based on confidence score
instance_idx = sorted_indices[i]
pasted_mask = self._paste_mask(
box=boxes[instance_idx],
mask=detections_masks[instance_idx])
pasted_mask = pasted_masks[instance_idx]
class_id = tf.cast(classes[instance_idx], dtype=tf.float32)
......@@ -248,7 +227,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
batched_scores = detections['detection_scores']
batched_classes = detections['detection_classes']
batched_boxes = tf.cast(detections['detection_boxes'], dtype=tf.int32)
batched_boxes = detections['detection_boxes']
batched_detections_masks = tf.expand_dims(
detections['detection_masks'], axis=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment