Unverified Commit e9fa7cd5 authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

Make is_thing_map in Feature Extractor post_process_panoptic_segmentation...

Make is_thing_map in Feature Extractor post_process_panoptic_segmentation defaults to all instances (#15954)

* is_thing_map defaults to all instances

* better naming

* control flow

* resolving conversations
parent 2596f95e
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Feature extractor class for MaskFormer.""" """Feature extractor class for MaskFormer."""
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -466,7 +466,7 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -466,7 +466,7 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
outputs: "MaskFormerForInstanceSegmentationOutput", outputs: "MaskFormerForInstanceSegmentationOutput",
object_mask_threshold: float = 0.8, object_mask_threshold: float = 0.8,
overlap_mask_area_threshold: float = 0.8, overlap_mask_area_threshold: float = 0.8,
is_thing_map: Optional[Dict[int, bool]] = None, label_ids_to_fuse: Optional[Set[int]] = None,
) -> List[Dict]: ) -> List[Dict]:
""" """
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
...@@ -479,23 +479,23 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -479,23 +479,23 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
The object mask threshold. The object mask threshold.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
The overlap mask area threshold to use. The overlap mask area threshold to use.
is_thing_map (`Dict[int, bool]`, *optional*): label_ids_to_fuse (`Set[int]`, *optional*):
Dictionary mapping class indices to either `True` or `False`, depending on whether or not they are a The labels in this state will have all their instances be fused together. For instance we could say
thing. If not set, defaults to the `is_thing_map` of COCO panoptic. there can only be one sky in an image, but several persons, so the label ID for sky would be in that
set, but not the one for person.
Returns: Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`. - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`.
- **segments** -- a dictionary with the following keys - **segments** -- a dictionary with the following keys
- **id** -- an integer representing the `segment_id`. - **id** -- an integer representing the `segment_id`.
- **category_id** -- an integer representing the segment's label. - **label_id** -- an integer representing the segment's label.
- **is_thing** -- a boolean, `True` if `category_id` was in `is_thing_map`, `False` otherwise. - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
""" """
if is_thing_map is None: if label_ids_to_fuse is None:
logger.warning("`is_thing_map` unset. Default to COCO.") logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
# default to is_thing_map of COCO panoptic label_ids_to_fuse = set()
is_thing_map = {i: i <= 90 for i in range(201)}
# class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1] # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
class_queries_logits = outputs.class_queries_logits class_queries_logits = outputs.class_queries_logits
# keep track of the number of labels, subtract -1 for null class # keep track of the number of labels, subtract -1 for null class
...@@ -531,8 +531,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -531,8 +531,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# this is a map between stuff and segments id, the used it to keep track of the instances of one class # this is a map between stuff and segments id, the used it to keep track of the instances of one class
for k in range(pred_labels.shape[0]): for k in range(pred_labels.shape[0]):
pred_class = pred_labels[k].item() pred_class = pred_labels[k].item()
# check if pred_class is not a "thing", so it can be merged with other instance. For example, class "sky" cannot have more then one instance # check if pred_class should be fused. For example, class "sky" cannot have more then one instance
is_stuff = not is_thing_map[pred_class] should_fuse = pred_class in label_ids_to_fuse
# get the mask associated with the k class # get the mask associated with the k class
mask_k = mask_labels == k mask_k = mask_labels == k
# create the area, since bool we just need to sum :) # create the area, since bool we just need to sum :)
...@@ -540,9 +540,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -540,9 +540,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# this is the area of all the stuff in query k # this is the area of all the stuff in query k
original_area = (mask_probs[k] >= 0.5).sum() original_area = (mask_probs[k] >= 0.5).sum()
mask_does_exist = mask_k_area > 0 and original_area > 0 mask_exists = mask_k_area > 0 and original_area > 0
if mask_does_exist: if mask_exists:
# find out how much of the all area mask_k is using # find out how much of the all area mask_k is using
area_ratio = mask_k_area / original_area area_ratio = mask_k_area / original_area
mask_k_is_overlapping_enough = area_ratio.item() > overlap_mask_area_threshold mask_k_is_overlapping_enough = area_ratio.item() > overlap_mask_area_threshold
...@@ -558,11 +558,11 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -558,11 +558,11 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
segments.append( segments.append(
{ {
"id": current_segment_id, "id": current_segment_id,
"category_id": pred_class, "label_id": pred_class,
"is_thing": not is_stuff, "was_fused": should_fuse,
} }
) )
if is_stuff: if should_fuse:
stuff_memory_list[pred_class] = current_segment_id stuff_memory_list[pred_class] = current_segment_id
results.append({"segmentation": segmentation, "segments": segments}) results.append({"segmentation": segmentation, "segments": segments})
return results return results
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment