Unverified Commit e9fa7cd5 authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

Make is_thing_map in Feature Extractor post_process_panoptic_segmentation...

Make is_thing_map in Feature Extractor post_process_panoptic_segmentation defaults to all instances (#15954)

* is_thing_map defaults to all instances

* better naming

* control flow

* resolving conversations
parent 2596f95e
......@@ -14,7 +14,7 @@
# limitations under the License.
"""Feature extractor class for MaskFormer."""
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
import numpy as np
from PIL import Image
......@@ -466,7 +466,7 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
outputs: "MaskFormerForInstanceSegmentationOutput",
object_mask_threshold: float = 0.8,
overlap_mask_area_threshold: float = 0.8,
is_thing_map: Optional[Dict[int, bool]] = None,
label_ids_to_fuse: Optional[Set[int]] = None,
) -> List[Dict]:
"""
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
......@@ -479,23 +479,23 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
The object mask threshold.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
The overlap mask area threshold to use.
is_thing_map (`Dict[int, bool]`, *optional*):
Dictionary mapping class indices to either `True` or `False`, depending on whether or not they are a
thing. If not set, defaults to the `is_thing_map` of COCO panoptic.
label_ids_to_fuse (`Set[int]`, *optional*):
The labels in this state will have all their instances be fused together. For instance we could say
there can only be one sky in an image, but several persons, so the label ID for sky would be in that
set, but not the one for person.
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`.
- **segments** -- a dictionary with the following keys
- **id** -- an integer representing the `segment_id`.
- **category_id** -- an integer representing the segment's label.
- **is_thing** -- a boolean, `True` if `category_id` was in `is_thing_map`, `False` otherwise.
- **label_id** -- an integer representing the segment's label.
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
"""
if is_thing_map is None:
logger.warning("`is_thing_map` unset. Default to COCO.")
# default to is_thing_map of COCO panoptic
is_thing_map = {i: i <= 90 for i in range(201)}
if label_ids_to_fuse is None:
logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
label_ids_to_fuse = set()
# class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
class_queries_logits = outputs.class_queries_logits
# keep track of the number of labels, subtract -1 for null class
......@@ -531,8 +531,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# this is a map between stuff and segments id, the used it to keep track of the instances of one class
for k in range(pred_labels.shape[0]):
pred_class = pred_labels[k].item()
# check if pred_class is not a "thing", so it can be merged with other instance. For example, class "sky" cannot have more then one instance
is_stuff = not is_thing_map[pred_class]
# check if pred_class should be fused. For example, class "sky" cannot have more then one instance
should_fuse = pred_class in label_ids_to_fuse
# get the mask associated with the k class
mask_k = mask_labels == k
# create the area, since bool we just need to sum :)
......@@ -540,9 +540,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# this is the area of all the stuff in query k
original_area = (mask_probs[k] >= 0.5).sum()
mask_does_exist = mask_k_area > 0 and original_area > 0
mask_exists = mask_k_area > 0 and original_area > 0
if mask_does_exist:
if mask_exists:
# find out how much of the all area mask_k is using
area_ratio = mask_k_area / original_area
mask_k_is_overlapping_enough = area_ratio.item() > overlap_mask_area_threshold
......@@ -558,11 +558,11 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
segments.append(
{
"id": current_segment_id,
"category_id": pred_class,
"is_thing": not is_stuff,
"label_id": pred_class,
"was_fused": should_fuse,
}
)
if is_stuff:
if should_fuse:
stuff_memory_list[pred_class] = current_segment_id
results.append({"segmentation": segmentation, "segments": segments})
return results
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment