Unverified Commit 19ed6e59 authored by srihari-humbarwadi

support generating masks at model input size

parent 6366773f
......@@ -84,6 +84,7 @@ class PanopticSegmentationGenerator(hyperparams.Config):
things_class_label: int = 1
void_class_label: int = 0
void_instance_id: int = 0
rescale_predictions: bool = False
@dataclasses.dataclass
......
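The new `rescale_predictions` field defaults to `False`, which keeps the generated panoptic masks at the model input size. Below is a minimal sketch of setting the flag on the postprocessing config; the import path is an assumption and may differ from the actual project layout in the repository.

# Hedged sketch: the import path below is assumed, not confirmed by this diff.
from official.vision.beta.projects.panoptic_maskrcnn.configs import (
    panoptic_maskrcnn as panoptic_cfg)

postprocessing_config = panoptic_cfg.PanopticSegmentationGenerator(
    things_class_label=1,
    void_class_label=0,
    void_instance_id=0,
    # False (default): masks are generated at the model input size.
    # True: predictions are rescaled to the original image size via image_info.
    rescale_predictions=False,
)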
......@@ -110,7 +110,8 @@ def build_panoptic_maskrcnn(
things_class_label=postprocessing_config.things_class_label,
stuff_area_threshold=postprocessing_config.stuff_area_threshold,
void_class_label=postprocessing_config.void_class_label,
void_instance_id=postprocessing_config.void_instance_id)
void_instance_id=postprocessing_config.void_instance_id,
rescale_predictions=postprocessing_config.rescale_predictions)
else:
panoptic_segmentation_generator_obj = None
......
......@@ -35,6 +35,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
things_class_label: int = 1,
void_class_label: int = 0,
void_instance_id: int = -1,
rescale_predictions: bool = False,
**kwargs):
"""Generates panoptic segmentation masks.
......@@ -55,6 +56,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
void_instance_id: An `int` that is used to denote regions that are not
assigned to any thing class. That is, void_instance_id is assigned to
both stuff regions and empty regions.
rescale_predictions: A `bool` indicating whether to scale predictions back
to the original image size. If True, the image_info `dict` is used to
rescale predictions.
**kwargs: Additional keyword arguments.
"""
self._output_size = output_size
......@@ -67,6 +70,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
self._things_class_label = things_class_label
self._void_class_label = void_class_label
self._void_instance_id = void_instance_id
self._rescale_predictions = rescale_predictions
self._config_dict = {
'output_size': output_size,
......@@ -76,7 +80,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
'score_threshold': score_threshold,
'things_class_label': things_class_label,
'void_class_label': void_class_label,
'void_instance_id': void_instance_id
'void_instance_id': void_instance_id,
'rescale_predictions': rescale_predictions
}
super(PanopticSegmentationGenerator, self).__init__(**kwargs)
......@@ -250,23 +255,30 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
batched_classes = detections['detection_classes']
batched_detections_masks = tf.expand_dims(
detections['detection_masks'], axis=-1)
batched_boxes = detections['detection_boxes']
scale = tf.tile(
tf.cast(image_info[:, 2:3, :], dtype=batched_boxes.dtype),
multiples=[1, 1, 2])
batched_boxes /= scale
batched_segmentation_masks = tf.cast(
detections['segmentation_outputs'], dtype=tf.float32)
batched_segmentation_masks = tf.map_fn(
fn=lambda x: self._resize_and_pad_masks(
x[0], x[1]),
elems=(
batched_segmentation_masks,
image_info),
fn_output_signature=tf.float32,
parallel_iterations=32)
if self._rescale_predictions:
scale = tf.tile(
tf.cast(image_info[:, 2:3, :], dtype=batched_boxes.dtype),
multiples=[1, 1, 2])
batched_boxes /= scale
batched_segmentation_masks = tf.map_fn(
fn=lambda x: self._resize_and_pad_masks(
x[0], x[1]),
elems=(
batched_segmentation_masks,
image_info),
fn_output_signature=tf.float32,
parallel_iterations=32)
else:
batched_segmentation_masks = tf.image.resize(
batched_segmentation_masks,
size=self._output_size,
method='bilinear')
batched_segmentation_masks = tf.expand_dims(tf.cast(
tf.argmax(batched_segmentation_masks, axis=-1),
dtype=tf.float32), axis=-1)
......
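The branch added in the last hunk is the core of the change: when `rescale_predictions` is off, the segmentation logits are simply resized to `output_size` with bilinear interpolation instead of being mapped back through `image_info`. The following is a self-contained sketch of the two paths on dummy tensors; the `image_info` layout (row 2 holding the resize scale) is assumed from the TF Model Garden preprocessing convention and is not defined by this diff, and the per-image mask resize-and-pad done by `_resize_and_pad_masks` is omitted.

import tensorflow as tf

# Hedged, self-contained sketch of the two post-processing paths.
rescale_predictions = False          # False -> masks stay at the model input size
output_size = [320, 320]

segmentation_outputs = tf.random.uniform([1, 160, 160, 3])   # [batch, h, w, classes]
detection_boxes = tf.constant([[[10., 10., 100., 100.]]])    # [batch, num_dets, 4]
# Assumed layout: [[orig_h, orig_w], [input_h, input_w],
#                  [scale_y, scale_x], [offset_y, offset_x]]
image_info = tf.constant([[[512., 512.], [320., 320.], [0.625, 0.625], [0., 0.]]])

if rescale_predictions:
  # Undo the preprocessing scale so boxes live in original image space
  # (the real layer also resizes and pads the masks per image; omitted here).
  scale = tf.tile(
      tf.cast(image_info[:, 2:3, :], detection_boxes.dtype), multiples=[1, 1, 2])
  detection_boxes /= scale
else:
  # New path: keep everything at the model input resolution and just resize
  # the segmentation logits to the requested output size.
  segmentation_outputs = tf.image.resize(
      segmentation_outputs, size=output_size, method='bilinear')

category_mask = tf.expand_dims(
    tf.cast(tf.argmax(segmentation_outputs, axis=-1), tf.float32), axis=-1)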