"vscode:/vscode.git/clone" did not exist on "417c5578e1c30c86c634076bdd9c8ccc58a97d18"
Unverified Commit 19ed6e59 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

support generating masks at model input size

parent 6366773f
...@@ -84,6 +84,7 @@ class PanopticSegmentationGenerator(hyperparams.Config): ...@@ -84,6 +84,7 @@ class PanopticSegmentationGenerator(hyperparams.Config):
things_class_label: int = 1 things_class_label: int = 1
void_class_label: int = 0 void_class_label: int = 0
void_instance_id: int = 0 void_instance_id: int = 0
rescale_predictions: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -110,7 +110,8 @@ def build_panoptic_maskrcnn( ...@@ -110,7 +110,8 @@ def build_panoptic_maskrcnn(
things_class_label=postprocessing_config.things_class_label, things_class_label=postprocessing_config.things_class_label,
stuff_area_threshold=postprocessing_config.stuff_area_threshold, stuff_area_threshold=postprocessing_config.stuff_area_threshold,
void_class_label=postprocessing_config.void_class_label, void_class_label=postprocessing_config.void_class_label,
void_instance_id=postprocessing_config.void_instance_id) void_instance_id=postprocessing_config.void_instance_id,
rescale_predictions=postprocessing_config.rescale_predictions)
else: else:
panoptic_segmentation_generator_obj = None panoptic_segmentation_generator_obj = None
......
...@@ -35,6 +35,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -35,6 +35,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
things_class_label: int = 1, things_class_label: int = 1,
void_class_label: int = 0, void_class_label: int = 0,
void_instance_id: int = -1, void_instance_id: int = -1,
rescale_predictions: bool = False,
**kwargs): **kwargs):
"""Generates panoptic segmentation masks. """Generates panoptic segmentation masks.
...@@ -55,6 +56,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -55,6 +56,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
void_instance_id: An `int` that is used to denote regions that are not void_instance_id: An `int` that is used to denote regions that are not
assigned to any thing class. That is, void_instance_id are assigned to assigned to any thing class. That is, void_instance_id are assigned to
both stuff regions and empty regions. both stuff regions and empty regions.
rescale_predictions: `bool`, whether to scale back prediction to original
image sizes. If True, image_info `dict` is used to rescale predictions.
**kwargs: additional kewargs arguments. **kwargs: additional kewargs arguments.
""" """
self._output_size = output_size self._output_size = output_size
...@@ -67,6 +70,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -67,6 +70,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
self._things_class_label = things_class_label self._things_class_label = things_class_label
self._void_class_label = void_class_label self._void_class_label = void_class_label
self._void_instance_id = void_instance_id self._void_instance_id = void_instance_id
self._rescale_predictions = rescale_predictions
self._config_dict = { self._config_dict = {
'output_size': output_size, 'output_size': output_size,
...@@ -76,7 +80,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -76,7 +80,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
'score_threshold': score_threshold, 'score_threshold': score_threshold,
'things_class_label': things_class_label, 'things_class_label': things_class_label,
'void_class_label': void_class_label, 'void_class_label': void_class_label,
'void_instance_id': void_instance_id 'void_instance_id': void_instance_id,
'rescale_predictions': rescale_predictions
} }
super(PanopticSegmentationGenerator, self).__init__(**kwargs) super(PanopticSegmentationGenerator, self).__init__(**kwargs)
...@@ -250,23 +255,30 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -250,23 +255,30 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
batched_classes = detections['detection_classes'] batched_classes = detections['detection_classes']
batched_detections_masks = tf.expand_dims( batched_detections_masks = tf.expand_dims(
detections['detection_masks'], axis=-1) detections['detection_masks'], axis=-1)
batched_boxes = detections['detection_boxes'] batched_boxes = detections['detection_boxes']
scale = tf.tile(
tf.cast(image_info[:, 2:3, :], dtype=batched_boxes.dtype),
multiples=[1, 1, 2])
batched_boxes /= scale
batched_segmentation_masks = tf.cast( batched_segmentation_masks = tf.cast(
detections['segmentation_outputs'], dtype=tf.float32) detections['segmentation_outputs'], dtype=tf.float32)
batched_segmentation_masks = tf.map_fn(
fn=lambda x: self._resize_and_pad_masks( if self._rescale_predictions:
x[0], x[1]), scale = tf.tile(
elems=( tf.cast(image_info[:, 2:3, :], dtype=batched_boxes.dtype),
batched_segmentation_masks, multiples=[1, 1, 2])
image_info), batched_boxes /= scale
fn_output_signature=tf.float32,
parallel_iterations=32) batched_segmentation_masks = tf.map_fn(
fn=lambda x: self._resize_and_pad_masks(
x[0], x[1]),
elems=(
batched_segmentation_masks,
image_info),
fn_output_signature=tf.float32,
parallel_iterations=32)
else:
batched_segmentation_masks = tf.image.resize(
batched_segmentation_masks,
size=self._output_size,
method='bilinear')
batched_segmentation_masks = tf.expand_dims(tf.cast( batched_segmentation_masks = tf.expand_dims(tf.cast(
tf.argmax(batched_segmentation_masks, axis=-1), tf.argmax(batched_segmentation_masks, axis=-1),
dtype=tf.float32), axis=-1) dtype=tf.float32), axis=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment