Unverified Commit c09b41da authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

filter things and stuff segments based on area

parent 0e4c6a5b
......@@ -79,6 +79,8 @@ class PanopticSegmentationGenerator(hyperparams.Config):
default_factory=list)
mask_binarize_threshold: float = 0.5
score_threshold: float = 0.5
things_overlap_threshold: float = 0.5
stuff_area_threshold: float = 4096.0
things_class_label: int = 1
void_class_label: int = 0
void_instance_id: int = 0
......
......@@ -106,7 +106,9 @@ def build_panoptic_maskrcnn(
stuff_classes_offset=model_config.stuff_classes_offset,
mask_binarize_threshold=mask_binarize_threshold,
score_threshold=postprocessing_config.score_threshold,
things_overlap_threshold=postprocessing_config.things_overlap_threshold,
things_class_label=postprocessing_config.things_class_label,
stuff_area_threshold=postprocessing_config.stuff_area_threshold,
void_class_label=postprocessing_config.void_class_label,
void_instance_id=postprocessing_config.void_instance_id)
else:
......
......@@ -23,16 +23,18 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
"""Panoptic segmentation generator layer."""
def __init__(
self,
output_size: List[int],
max_num_detections: int,
stuff_classes_offset: int,
mask_binarize_threshold: float = 0.5,
score_threshold: float = 0.05,
things_class_label: int = 1,
void_class_label: int = 0,
void_instance_id: int = -1,
**kwargs):
self,
output_size: List[int],
max_num_detections: int,
stuff_classes_offset: int,
mask_binarize_threshold: float = 0.5,
score_threshold: float = 0.5,
things_overlap_threshold: float = 0.5,
stuff_area_threshold: float = 4096,
things_class_label: int = 1,
void_class_label: int = 0,
void_instance_id: int = -1,
**kwargs):
"""Generates panoptic segmentation masks.
Args:
......@@ -59,6 +61,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
self._stuff_classes_offset = stuff_classes_offset
self._mask_binarize_threshold = mask_binarize_threshold
self._score_threshold = score_threshold
self._things_overlap_threshold = things_overlap_threshold
self._stuff_area_threshold = stuff_area_threshold
self._things_class_label = things_class_label
self._void_class_label = void_class_label
self._void_instance_id = void_instance_id
......@@ -184,6 +188,19 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
if not tf.reduce_sum(tf.cast(binary_mask, tf.float32)) > 0:
continue
overlap = tf.logical_and(
binary_mask,
tf.not_equal(category_mask, self._void_class_label))
binary_mask_area = tf.reduce_sum(
tf.cast(binary_mask, dtype=tf.float32))
overlap_area = tf.reduce_sum(
tf.cast(overlap, dtype=tf.float32))
# skip instance that have a big enough overlap with instances with
# higer scores
if overlap_area / binary_mask_area > self._things_overlap_threshold:
continue
# fill empty regions in category_mask represented by
# void_class_label with class_id of the instance.
category_mask = tf.where(
......@@ -200,18 +217,25 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
tf.ones_like(instance_mask) *
tf.cast(instance_idx + 1, tf.float32), instance_mask)
# add stuff segmentation labels to empty regions of category_mask.
# we ignore the pixels labelled as "things", since we get them from
# the instance masks.
# TODO(srihari, arashwan): Support filtering stuff classes based on area.
category_mask = tf.where(
tf.logical_and(
tf.equal(
category_mask, self._void_class_label),
tf.logical_and(
tf.not_equal(segmentation_mask, self._things_class_label),
tf.not_equal(segmentation_mask, self._void_class_label))),
segmentation_mask, category_mask)
stuff_class_ids = tf.unique(tf.reshape(segmentation_mask, [-1])).y
for stuff_class_id in stuff_class_ids:
if stuff_class_id == self._things_class_label:
continue
stuff_mask = tf.logical_and(
tf.equal(segmentation_mask, stuff_class_id),
tf.equal(category_mask, self._void_class_label))
stuff_mask_area = tf.reduce_sum(
tf.cast(stuff_mask, dtype=tf.float32))
if stuff_mask_area < self._stuff_area_threshold:
continue
category_mask = tf.where(
stuff_mask,
tf.ones_like(category_mask) * stuff_class_id,
category_mask)
results = {
'category_mask': category_mask[:, :, 0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment