Unverified Commit c09b41da authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

filter things and stuff segments based on area

parent 0e4c6a5b
...@@ -79,6 +79,8 @@ class PanopticSegmentationGenerator(hyperparams.Config): ...@@ -79,6 +79,8 @@ class PanopticSegmentationGenerator(hyperparams.Config):
default_factory=list) default_factory=list)
mask_binarize_threshold: float = 0.5 mask_binarize_threshold: float = 0.5
score_threshold: float = 0.5 score_threshold: float = 0.5
things_overlap_threshold: float = 0.5
stuff_area_threshold: float = 4096.0
things_class_label: int = 1 things_class_label: int = 1
void_class_label: int = 0 void_class_label: int = 0
void_instance_id: int = 0 void_instance_id: int = 0
......
...@@ -106,7 +106,9 @@ def build_panoptic_maskrcnn( ...@@ -106,7 +106,9 @@ def build_panoptic_maskrcnn(
stuff_classes_offset=model_config.stuff_classes_offset, stuff_classes_offset=model_config.stuff_classes_offset,
mask_binarize_threshold=mask_binarize_threshold, mask_binarize_threshold=mask_binarize_threshold,
score_threshold=postprocessing_config.score_threshold, score_threshold=postprocessing_config.score_threshold,
things_overlap_threshold=postprocessing_config.things_overlap_threshold,
things_class_label=postprocessing_config.things_class_label, things_class_label=postprocessing_config.things_class_label,
stuff_area_threshold=postprocessing_config.stuff_area_threshold,
void_class_label=postprocessing_config.void_class_label, void_class_label=postprocessing_config.void_class_label,
void_instance_id=postprocessing_config.void_instance_id) void_instance_id=postprocessing_config.void_instance_id)
else: else:
......
...@@ -23,16 +23,18 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -23,16 +23,18 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
"""Panoptic segmentation generator layer.""" """Panoptic segmentation generator layer."""
def __init__( def __init__(
self, self,
output_size: List[int], output_size: List[int],
max_num_detections: int, max_num_detections: int,
stuff_classes_offset: int, stuff_classes_offset: int,
mask_binarize_threshold: float = 0.5, mask_binarize_threshold: float = 0.5,
score_threshold: float = 0.05, score_threshold: float = 0.5,
things_class_label: int = 1, things_overlap_threshold: float = 0.5,
void_class_label: int = 0, stuff_area_threshold: float = 4096,
void_instance_id: int = -1, things_class_label: int = 1,
**kwargs): void_class_label: int = 0,
void_instance_id: int = -1,
**kwargs):
"""Generates panoptic segmentation masks. """Generates panoptic segmentation masks.
Args: Args:
...@@ -59,6 +61,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -59,6 +61,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
self._stuff_classes_offset = stuff_classes_offset self._stuff_classes_offset = stuff_classes_offset
self._mask_binarize_threshold = mask_binarize_threshold self._mask_binarize_threshold = mask_binarize_threshold
self._score_threshold = score_threshold self._score_threshold = score_threshold
self._things_overlap_threshold = things_overlap_threshold
self._stuff_area_threshold = stuff_area_threshold
self._things_class_label = things_class_label self._things_class_label = things_class_label
self._void_class_label = void_class_label self._void_class_label = void_class_label
self._void_instance_id = void_instance_id self._void_instance_id = void_instance_id
...@@ -184,6 +188,19 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -184,6 +188,19 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
if not tf.reduce_sum(tf.cast(binary_mask, tf.float32)) > 0: if not tf.reduce_sum(tf.cast(binary_mask, tf.float32)) > 0:
continue continue
overlap = tf.logical_and(
binary_mask,
tf.not_equal(category_mask, self._void_class_label))
binary_mask_area = tf.reduce_sum(
tf.cast(binary_mask, dtype=tf.float32))
overlap_area = tf.reduce_sum(
tf.cast(overlap, dtype=tf.float32))
# skip instance that have a big enough overlap with instances with
# higer scores
if overlap_area / binary_mask_area > self._things_overlap_threshold:
continue
# fill empty regions in category_mask represented by # fill empty regions in category_mask represented by
# void_class_label with class_id of the instance. # void_class_label with class_id of the instance.
category_mask = tf.where( category_mask = tf.where(
...@@ -200,18 +217,25 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -200,18 +217,25 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
tf.ones_like(instance_mask) * tf.ones_like(instance_mask) *
tf.cast(instance_idx + 1, tf.float32), instance_mask) tf.cast(instance_idx + 1, tf.float32), instance_mask)
# add stuff segmentation labels to empty regions of category_mask. stuff_class_ids = tf.unique(tf.reshape(segmentation_mask, [-1])).y
# we ignore the pixels labelled as "things", since we get them from for stuff_class_id in stuff_class_ids:
# the instance masks. if stuff_class_id == self._things_class_label:
# TODO(srihari, arashwan): Support filtering stuff classes based on area. continue
category_mask = tf.where(
tf.logical_and( stuff_mask = tf.logical_and(
tf.equal( tf.equal(segmentation_mask, stuff_class_id),
category_mask, self._void_class_label), tf.equal(category_mask, self._void_class_label))
tf.logical_and(
tf.not_equal(segmentation_mask, self._things_class_label), stuff_mask_area = tf.reduce_sum(
tf.not_equal(segmentation_mask, self._void_class_label))), tf.cast(stuff_mask, dtype=tf.float32))
segmentation_mask, category_mask)
if stuff_mask_area < self._stuff_area_threshold:
continue
category_mask = tf.where(
stuff_mask,
tf.ones_like(category_mask) * stuff_class_id,
category_mask)
results = { results = {
'category_mask': category_mask[:, :, 0], 'category_mask': category_mask[:, :, 0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment