Unverified Commit 3982b139 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge pull request #10 from srihari-humbarwadi/rescale_predictions-2

Generate panoptic masks at original image resolution
parents 0bee3f54 1f2ec20d
...@@ -32,28 +32,13 @@ import tensorflow as tf ...@@ -32,28 +32,13 @@ import tensorflow as tf
from official.vision.beta.evaluation import panoptic_quality from official.vision.beta.evaluation import panoptic_quality
def _crop_padding(mask, image_info):
  """Crops padding pixels from a mask so it matches the valid image region.

  Args:
    mask: A [height, width] `tf.Tensor` holding a per-pixel mask whose
      spatial extent may include padding added during preprocessing.
    image_info: A [4, 2] `tf.Tensor` of image metadata; row 0 holds the
      desired (unpadded) image height and width.
      # NOTE(review): assumes the standard model-garden image_info layout
      # ([original, desired, scale, offset]) — confirm against the dataloader.

  Returns:
    A [1, height, width] `tf.Tensor` cropped to the desired image shape,
    with a leading batch dimension added.
  """
  image_shape = tf.cast(image_info[0, :], tf.int32)
  # crop_to_bounding_box needs a channel axis; add it, crop from the
  # top-left corner, then drop the channel axis again.
  mask = tf.image.crop_to_bounding_box(
      tf.expand_dims(mask, axis=-1), 0, 0,
      image_shape[0], image_shape[1])
  return tf.expand_dims(mask[:, :, 0], axis=0)
class PanopticQualityEvaluator: class PanopticQualityEvaluator:
"""Panoptic Quality metric class.""" """Panoptic Quality metric class."""
...@@ -169,21 +154,17 @@ class PanopticQualityEvaluator: ...@@ -169,21 +154,17 @@ class PanopticQualityEvaluator:
if self._rescale_predictions: if self._rescale_predictions:
for idx in range(len(groundtruths['category_mask'])): for idx in range(len(groundtruths['category_mask'])):
image_info = groundtruths['image_info'][idx] image_info = groundtruths['image_info'][idx]
groundtruth_category_mask, prediction_category_mask = rescale_masks(
groundtruths['category_mask'][idx],
predictions['category_mask'][idx],
image_info)
groundtruth_instance_mask, prediction_instance_mask = rescale_masks(
groundtruths['instance_mask'][idx],
predictions['instance_mask'][idx],
image_info)
_groundtruths = { _groundtruths = {
'category_mask': groundtruth_category_mask, 'category_mask':
'instance_mask': groundtruth_instance_mask _crop_padding(groundtruths['category_mask'][idx], image_info),
'instance_mask':
_crop_padding(groundtruths['instance_mask'][idx], image_info),
} }
_predictions = { _predictions = {
'category_mask': prediction_category_mask, 'category_mask':
'instance_mask': prediction_instance_mask _crop_padding(predictions['category_mask'][idx], image_info),
'instance_mask':
_crop_padding(predictions['instance_mask'][idx], image_info),
} }
_groundtruths, _predictions = self._convert_to_numpy( _groundtruths, _predictions = self._convert_to_numpy(
_groundtruths, _predictions) _groundtruths, _predictions)
......
...@@ -84,6 +84,7 @@ class PanopticSegmentationGenerator(hyperparams.Config): ...@@ -84,6 +84,7 @@ class PanopticSegmentationGenerator(hyperparams.Config):
things_class_label: int = 1 things_class_label: int = 1
void_class_label: int = 0 void_class_label: int = 0
void_instance_id: int = 0 void_instance_id: int = 0
rescale_predictions: bool = False
@dataclasses.dataclass @dataclasses.dataclass
...@@ -182,7 +183,7 @@ def panoptic_fpn_coco() -> cfg.ExperimentConfig: ...@@ -182,7 +183,7 @@ def panoptic_fpn_coco() -> cfg.ExperimentConfig:
model=PanopticMaskRCNN( model=PanopticMaskRCNN(
num_classes=91, input_size=[1024, 1024, 3], num_classes=91, input_size=[1024, 1024, 3],
panoptic_segmentation_generator=PanopticSegmentationGenerator( panoptic_segmentation_generator=PanopticSegmentationGenerator(
output_size=[1024, 1024]), output_size=[640, 640], rescale_predictions=True),
stuff_classes_offset=90, stuff_classes_offset=90,
segmentation_model=SEGMENTATION_MODEL( segmentation_model=SEGMENTATION_MODEL(
num_classes=num_semantic_segmentation_classes, num_classes=num_semantic_segmentation_classes,
......
...@@ -110,7 +110,8 @@ def build_panoptic_maskrcnn( ...@@ -110,7 +110,8 @@ def build_panoptic_maskrcnn(
things_class_label=postprocessing_config.things_class_label, things_class_label=postprocessing_config.things_class_label,
stuff_area_threshold=postprocessing_config.stuff_area_threshold, stuff_area_threshold=postprocessing_config.stuff_area_threshold,
void_class_label=postprocessing_config.void_class_label, void_class_label=postprocessing_config.void_class_label,
void_instance_id=postprocessing_config.void_instance_id) void_instance_id=postprocessing_config.void_instance_id,
rescale_predictions=postprocessing_config.rescale_predictions)
else: else:
panoptic_segmentation_generator_obj = None panoptic_segmentation_generator_obj = None
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""Contains definition for postprocessing layer to genrate panoptic segmentations.""" """Contains definition for postprocessing layer to genrate panoptic segmentations."""
from typing import List from typing import List, Mapping
import tensorflow as tf import tensorflow as tf
...@@ -35,6 +35,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -35,6 +35,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
things_class_label: int = 1, things_class_label: int = 1,
void_class_label: int = 0, void_class_label: int = 0,
void_instance_id: int = -1, void_instance_id: int = -1,
rescale_predictions: bool = False,
**kwargs): **kwargs):
"""Generates panoptic segmentation masks. """Generates panoptic segmentation masks.
...@@ -55,6 +56,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -55,6 +56,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
void_instance_id: An `int` that is used to denote regions that are not void_instance_id: An `int` that is used to denote regions that are not
assigned to any thing class. That is, void_instance_id are assigned to assigned to any thing class. That is, void_instance_id are assigned to
both stuff regions and empty regions. both stuff regions and empty regions.
rescale_predictions: `bool`, whether to scale back prediction to original
image sizes. If True, image_info `dict` is used to rescale predictions.
**kwargs: additional keyword arguments. **kwargs: additional keyword arguments.
""" """
self._output_size = output_size self._output_size = output_size
...@@ -67,6 +70,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -67,6 +70,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
self._things_class_label = things_class_label self._things_class_label = things_class_label
self._void_class_label = void_class_label self._void_class_label = void_class_label
self._void_instance_id = void_instance_id self._void_instance_id = void_instance_id
self._rescale_predictions = rescale_predictions
self._config_dict = { self._config_dict = {
'output_size': output_size, 'output_size': output_size,
...@@ -76,7 +80,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -76,7 +80,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
'score_threshold': score_threshold, 'score_threshold': score_threshold,
'things_class_label': things_class_label, 'things_class_label': things_class_label,
'void_class_label': void_class_label, 'void_class_label': void_class_label,
'void_instance_id': void_instance_id 'void_instance_id': void_instance_id,
'rescale_predictions': rescale_predictions
} }
super(PanopticSegmentationGenerator, self).__init__(**kwargs) super(PanopticSegmentationGenerator, self).__init__(**kwargs)
...@@ -224,19 +229,56 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -224,19 +229,56 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
} }
return results return results
def call(self, inputs): def _resize_and_pad_masks(self, mask, image_info):
rescale_size = tf.cast(
tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
image_shape = tf.cast(image_info[0, :], tf.int32)
offsets = tf.cast(image_info[3, :], tf.int32)
mask = tf.image.resize(
mask,
rescale_size,
method=tf.image.ResizeMethod.BILINEAR)
mask = tf.image.crop_to_bounding_box(
mask,
offsets[0], offsets[1],
image_shape[0],
image_shape[1])
mask = tf.image.pad_to_bounding_box(
mask, 0, 0, self._output_size[0], self._output_size[1])
return mask
def call(self, inputs: tf.Tensor, image_info: Mapping[str, tf.Tensor]):
detections = inputs detections = inputs
batched_scores = detections['detection_scores'] batched_scores = detections['detection_scores']
batched_classes = detections['detection_classes'] batched_classes = detections['detection_classes']
batched_boxes = detections['detection_boxes']
batched_detections_masks = tf.expand_dims( batched_detections_masks = tf.expand_dims(
detections['detection_masks'], axis=-1) detections['detection_masks'], axis=-1)
batched_boxes = detections['detection_boxes']
batched_segmentation_masks = tf.cast(
detections['segmentation_outputs'], dtype=tf.float32)
if self._rescale_predictions:
scale = tf.tile(
tf.cast(image_info[:, 2:3, :], dtype=batched_boxes.dtype),
multiples=[1, 1, 2])
batched_boxes /= scale
batched_segmentation_masks = tf.map_fn(
fn=lambda x: self._resize_and_pad_masks(
x[0], x[1]),
elems=(
batched_segmentation_masks,
image_info),
fn_output_signature=tf.float32,
parallel_iterations=32)
else:
batched_segmentation_masks = tf.image.resize(
batched_segmentation_masks,
size=self._output_size,
method='bilinear')
batched_segmentation_masks = tf.image.resize(
detections['segmentation_outputs'],
size=self._output_size,
method='bilinear')
batched_segmentation_masks = tf.expand_dims(tf.cast( batched_segmentation_masks = tf.expand_dims(tf.cast(
tf.argmax(batched_segmentation_masks, axis=-1), tf.argmax(batched_segmentation_masks, axis=-1),
dtype=tf.float32), axis=-1) dtype=tf.float32), axis=-1)
...@@ -253,7 +295,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -253,7 +295,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
fn_output_signature={ fn_output_signature={
'category_mask': tf.float32, 'category_mask': tf.float32,
'instance_mask': tf.float32 'instance_mask': tf.float32
}) }, parallel_iterations=32)
for k, v in panoptic_masks.items(): for k, v in panoptic_masks.items():
panoptic_masks[k] = tf.cast(v, dtype=tf.int32) panoptic_masks[k] = tf.cast(v, dtype=tf.int32)
......
...@@ -143,12 +143,13 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -143,12 +143,13 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
def call(self, def call(self,
images: tf.Tensor, images: tf.Tensor,
image_shape: tf.Tensor, image_info: Mapping[str, tf.Tensor],
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None, anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
gt_boxes: Optional[tf.Tensor] = None, gt_boxes: Optional[tf.Tensor] = None,
gt_classes: Optional[tf.Tensor] = None, gt_classes: Optional[tf.Tensor] = None,
gt_masks: Optional[tf.Tensor] = None, gt_masks: Optional[tf.Tensor] = None,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]: training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
image_shape = image_info[:, 1, :]
model_outputs = super(PanopticMaskRCNNModel, self).call( model_outputs = super(PanopticMaskRCNNModel, self).call(
images=images, images=images,
image_shape=image_shape, image_shape=image_shape,
...@@ -177,7 +178,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -177,7 +178,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
}) })
if not training and self.panoptic_segmentation_generator is not None: if not training and self.panoptic_segmentation_generator is not None:
panoptic_outputs = self.panoptic_segmentation_generator(model_outputs) panoptic_outputs = self.panoptic_segmentation_generator(
model_outputs, image_info=image_info)
model_outputs.update({'panoptic_outputs': panoptic_outputs}) model_outputs.update({'panoptic_outputs': panoptic_outputs})
return model_outputs return model_outputs
......
...@@ -286,7 +286,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -286,7 +286,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = model( outputs = model(
images, images,
image_shape=labels['image_info'][:, 1, :], image_info=labels['image_info'],
anchor_boxes=labels['anchor_boxes'], anchor_boxes=labels['anchor_boxes'],
gt_boxes=labels['gt_boxes'], gt_boxes=labels['gt_boxes'],
gt_classes=labels['gt_classes'], gt_classes=labels['gt_classes'],
...@@ -355,7 +355,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -355,7 +355,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
outputs = model( outputs = model(
images, images,
anchor_boxes=labels['anchor_boxes'], anchor_boxes=labels['anchor_boxes'],
image_shape=labels['image_info'][:, 1, :], image_info=labels['image_info'],
training=False) training=False)
logs = {self.loss: 0} logs = {self.loss: 0}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment