Unverified commit f424b094, authored by Alara Dirik, committed by GitHub

Fix MaskFormerImageProcessor.post_process_instance_segmentation (#21256)

* fix instance segmentation post processing

* add Mask2FormerImageProcessor
parent 767939af
@@ -22,8 +22,8 @@ The abstract from the paper is the following:

of semantics defines a task. While only the semantics of each task differ, current research focuses on designing specialized architectures for each task. We present Masked-attention Mask Transformer (Mask2Former), a new architecture capable of addressing any image segmentation task (panoptic, instance or semantic). Its key components include masked attention, which extracts localized features by constraining cross-attention within predicted mask regions. In addition to reducing the research effort by at least three times, it outperforms the best specialized architectures by a significant margin on four popular datasets. Most notably, Mask2Former sets a new state-of-the-art for panoptic segmentation (57.8 PQ on COCO), instance segmentation (50.1 AP on COCO) and semantic segmentation (57.7 mIoU on ADE20K).*

Tips:
-- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`MaskFormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model.
+- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`Mask2FormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model.
-- To get the final segmentation, depending on the task, you can call [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`]. All three tasks can be solved using [`Mask2FormerForUniversalSegmentation`] output, panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object/s (e.g. sky) together.
+- To get the final segmentation, depending on the task, you can call [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or [`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or [`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`]. All three tasks can be solved using [`Mask2FormerForUniversalSegmentation`] output, panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object/s (e.g. sky) together.

This model was contributed by [Shivalika Singh](https://huggingface.co/shivi) and [Alara Dirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/facebookresearch/Mask2Former).
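For illustration only (not part of the commit): a minimal sketch of the preprocess → forward → post-process flow the tips above describe. The checkpoint name is the one used as `_CHECKPOINT_FOR_DOC` later in this diff; `example.jpg` is a placeholder path you would replace with your own image.

```python
import torch
from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

ckpt = "facebook/mask2former-swin-small-coco-instance"  # checkpoint used elsewhere in this PR
processor = Mask2FormerImageProcessor.from_pretrained(ckpt)
model = Mask2FormerForUniversalSegmentation.from_pretrained(ckpt)

image = Image.open("example.jpg").convert("RGB")  # placeholder image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Pick the post-processing that matches your task; target_sizes resizes maps back to (height, width).
instance = processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
print(instance["segmentation"].shape, len(instance["segments_info"]))
```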
@@ -55,3 +55,12 @@ The resource should ideally demonstrate something new instead of duplicating an

[[autodoc]] Mask2FormerForUniversalSegmentation
    - forward

+## Mask2FormerImageProcessor
+
+[[autodoc]] Mask2FormerImageProcessor
+    - preprocess
+    - encode_inputs
+    - post_process_semantic_segmentation
+    - post_process_instance_segmentation
+    - post_process_panoptic_segmentation
\ No newline at end of file
@@ -799,6 +799,7 @@ else:
    _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
    _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
    _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
+   _import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
    _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
    _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
    _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
@@ -4152,6 +4153,7 @@ if TYPE_CHECKING:
        from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2ImageProcessor
        from .models.layoutlmv3 import LayoutLMv3FeatureExtractor, LayoutLMv3ImageProcessor
        from .models.levit import LevitFeatureExtractor, LevitImageProcessor
+       from .models.mask2former import Mask2FormerImageProcessor
        from .models.maskformer import MaskFormerFeatureExtractor, MaskFormerImageProcessor
        from .models.mobilenet_v1 import MobileNetV1FeatureExtractor, MobileNetV1ImageProcessor
        from .models.mobilenet_v2 import MobileNetV2FeatureExtractor, MobileNetV2ImageProcessor
...
@@ -62,7 +62,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("layoutlmv2", "LayoutLMv2ImageProcessor"),
        ("layoutlmv3", "LayoutLMv3ImageProcessor"),
        ("levit", "LevitImageProcessor"),
-       ("mask2former", "MaskFormerImageProcessor"),
+       ("mask2former", "Mask2FormerImageProcessor"),
        ("maskformer", "MaskFormerImageProcessor"),
        ("mobilenet_v1", "MobileNetV1ImageProcessor"),
        ("mobilenet_v2", "MobileNetV2ImageProcessor"),
...
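The one-line mapping fix above is what makes `AutoImageProcessor` resolve Mask2Former checkpoints to the new class instead of `MaskFormerImageProcessor`. A quick hedged check (not part of the diff; checkpoint name taken from `_CHECKPOINT_FOR_DOC` later in this PR):

```python
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
print(type(processor).__name__)  # expected after this change: "Mask2FormerImageProcessor"
```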
@@ -27,6 +27,13 @@ _import_structure = {
    ],
}

+try:
+    if not is_vision_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["image_processing_mask2former"] = ["Mask2FormerImageProcessor"]

try:
    if not is_torch_available():

@@ -44,6 +51,14 @@ else:

if TYPE_CHECKING:
    from .configuration_mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig

+    try:
+        if not is_vision_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .image_processing_mask2former import Mask2FormerImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
...
@@ -33,8 +33,8 @@ from huggingface_hub import hf_hub_download
from transformers import (
    Mask2FormerConfig,
    Mask2FormerForUniversalSegmentation,
+    Mask2FormerImageProcessor,
    Mask2FormerModel,
-    MaskFormerImageProcessor,
    SwinConfig,
)
from transformers.models.mask2former.modeling_mask2former import (

@@ -193,11 +193,11 @@ class OriginalMask2FormerConfigToOursConverter:

class OriginalMask2FormerConfigToFeatureExtractorConverter:
-    def __call__(self, original_config: object) -> MaskFormerImageProcessor:
+    def __call__(self, original_config: object) -> Mask2FormerImageProcessor:
        model = original_config.MODEL
        model_input = original_config.INPUT

-        return MaskFormerImageProcessor(
+        return Mask2FormerImageProcessor(
            image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
            image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
            size=model_input.MIN_SIZE_TEST,

@@ -847,7 +847,7 @@ class OriginalMask2FormerCheckpointToOursConverter:
def test(
    original_model,
    our_model: Mask2FormerForUniversalSegmentation,
-    feature_extractor: MaskFormerImageProcessor,
+    feature_extractor: Mask2FormerImageProcessor,
    tolerance: float,
):
    with torch.no_grad():
...
@@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Mask2FormerConfig"
_CHECKPOINT_FOR_DOC = "facebook/mask2former-swin-small-coco-instance"
+_IMAGE_PROCESSOR_FOR_DOC = "Mask2FormerImageProcessor"

MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/mask2former-swin-small-coco-instance",

@@ -194,10 +195,10 @@ class Mask2FormerForUniversalSegmentationOutput(ModelOutput):
    """
    Class for outputs of [`Mask2FormerForUniversalSegmentationOutput`].

-    This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or
-    [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or
-    [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
-    [`~MaskFormerImageProcessor] for details regarding usage.
+    This output can be directly passed to [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or
+    [`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or
+    [`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
+    [`~Mask2FormerImageProcessor] for details regarding usage.

    Args:
        loss (`torch.Tensor`, *optional*):
...
@@ -1016,6 +1016,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
        overlap_mask_area_threshold: float = 0.8,
        target_sizes: Optional[List[Tuple[int, int]]] = None,
        return_coco_annotation: Optional[bool] = False,
+        return_binary_maps: Optional[bool] = False,
    ) -> List[Dict]:
        """
        Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into instance segmentation predictions. Only
@@ -1034,9 +1035,11 @@ class MaskFormerImageProcessor(BaseImageProcessor):
            target_sizes (`List[Tuple]`, *optional*):
                List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
                final size (height, width) of each prediction. If left to None, predictions will not be resized.
-            return_coco_annotation (`bool`, *optional*):
-                Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
-                format.
+            return_coco_annotation (`bool`, *optional*, defaults to `False`):
+                If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
+            return_binary_maps (`bool`, *optional*, defaults to `False`):
+                If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps
+                (one per detected instance).
        Returns:
            `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
            - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
@@ -1047,47 +1050,73 @@ class MaskFormerImageProcessor(BaseImageProcessor):
            - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
            - **score** -- Prediction score of segment with `segment_id`.
        """
-        class_queries_logits = outputs.class_queries_logits  # [batch_size, num_queries, num_classes+1]
-        masks_queries_logits = outputs.masks_queries_logits  # [batch_size, num_queries, height, width]
-
-        batch_size = class_queries_logits.shape[0]
-        num_labels = class_queries_logits.shape[-1] - 1
-
-        mask_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]
-
-        # Predicted label and score of each query (batch_size, num_queries)
-        pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
+        if return_coco_annotation and return_binary_maps:
+            raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.")
+
+        # [batch_size, num_queries, num_classes+1]
+        class_queries_logits = outputs.class_queries_logits
+        # [batch_size, num_queries, height, width]
+        masks_queries_logits = outputs.masks_queries_logits
+
+        device = masks_queries_logits.device
+        num_classes = class_queries_logits.shape[-1] - 1
+        num_queries = class_queries_logits.shape[-2]

        # Loop over items in batch size
        results: List[Dict[str, TensorType]] = []

-        for i in range(batch_size):
-            mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
-                mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
-            )
-
-            # No mask found
-            if mask_probs_item.shape[0] <= 0:
-                height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
-                segmentation = torch.zeros((height, width)) - 1
-                results.append({"segmentation": segmentation, "segments_info": []})
-                continue
-
-            # Get segmentation map and segment information of batch item
-            target_size = target_sizes[i] if target_sizes is not None else None
-            segmentation, segments = compute_segments(
-                mask_probs=mask_probs_item,
-                pred_scores=pred_scores_item,
-                pred_labels=pred_labels_item,
-                mask_threshold=mask_threshold,
-                overlap_mask_area_threshold=overlap_mask_area_threshold,
-                label_ids_to_fuse=[],
-                target_size=target_size,
-            )
-
-            # Return segmentation map in run-length encoding (RLE) format
-            if return_coco_annotation:
-                segmentation = convert_segmentation_to_rle(segmentation)
+        for i in range(class_queries_logits.shape[0]):
+            mask_pred = masks_queries_logits[i]
+            mask_cls = class_queries_logits[i]
+
+            scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1]
+            labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
+
+            scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
+            labels_per_image = labels[topk_indices]
+
+            topk_indices = topk_indices // num_classes
+            mask_pred = mask_pred[topk_indices]
+            pred_masks = (mask_pred > 0).float()
+
+            # Calculate average mask prob
+            mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
+                pred_masks.flatten(1).sum(1) + 1e-6
+            )
+            pred_scores = scores_per_image * mask_scores_per_image
+            pred_classes = labels_per_image
+
+            segmentation = torch.zeros(masks_queries_logits.shape[2:]) - 1
+            if target_sizes is not None:
+                segmentation = torch.zeros(target_sizes[i]) - 1
+                pred_masks = torch.nn.functional.interpolate(
+                    pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest"
+                )[0]
+
+            instance_maps, segments = [], []
+            current_segment_id = 0
+            for j in range(num_queries):
+                score = pred_scores[j].item()
+
+                if not torch.all(pred_masks[j] == 0) and score >= threshold:
+                    segmentation[pred_masks[j] == 1] = current_segment_id
+                    segments.append(
+                        {
+                            "id": current_segment_id,
+                            "label_id": pred_classes[j].item(),
+                            "was_fused": False,
+                            "score": round(score, 6),
+                        }
+                    )
+                    current_segment_id += 1
+                    instance_maps.append(pred_masks[j])
+
+            # Return segmentation map in run-length encoding (RLE) format
+            if return_coco_annotation:
+                segmentation = convert_segmentation_to_rle(segmentation)
+
+            # Return a concatenated tensor of binary instance maps
+            if return_binary_maps and len(instance_maps) != 0:
+                segmentation = torch.stack(instance_maps, dim=0)

            results.append({"segmentation": segmentation, "segments_info": segments})

        return results
...
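To make the new behavior concrete, here is a hedged sketch (not part of the diff) of the three output modes of the rewritten `post_process_instance_segmentation`. The checkpoint name again comes from `_CHECKPOINT_FOR_DOC` in this PR; `example.jpg` is a placeholder image path.

```python
import torch
from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

ckpt = "facebook/mask2former-swin-small-coco-instance"
processor = Mask2FormerImageProcessor.from_pretrained(ckpt)
model = Mask2FormerForUniversalSegmentation.from_pretrained(ckpt)

image = Image.open("example.jpg").convert("RGB")  # placeholder image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
target_sizes = [image.size[::-1]]  # PIL size is (width, height), post-processing wants (height, width)

# Default: a (height, width) map of segment ids plus per-segment metadata.
res = processor.post_process_instance_segmentation(outputs, target_sizes=target_sizes)[0]
print(res["segmentation"].shape, len(res["segments_info"]))

# COCO run-length encoding, one RLE per segment.
res_rle = processor.post_process_instance_segmentation(outputs, return_coco_annotation=True)[0]

# New in this commit: one binary map per detected instance, stacked into a
# (num_instances, height, width) tensor. Mutually exclusive with return_coco_annotation,
# as enforced by the ValueError at the top of the method.
res_bin = processor.post_process_instance_segmentation(outputs, return_binary_maps=True)[0]
print(res_bin["segmentation"].shape)
```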
@@ -269,6 +269,13 @@ class LevitImageProcessor(metaclass=DummyObject):
        requires_backends(self, ["vision"])


+class Mask2FormerImageProcessor(metaclass=DummyObject):
+    _backends = ["vision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["vision"])
+
+
class MaskFormerFeatureExtractor(metaclass=DummyObject):
    _backends = ["vision"]
...
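The dummy class above exists so that importing `Mask2FormerImageProcessor` never fails even when the vision backend is missing. A minimal sketch of that behavior (my illustration, not part of the diff; assumes Pillow / the "vision" extra is not installed):

```python
from transformers.utils import is_vision_available

if not is_vision_available():
    # The name still imports, but it resolves to the dummy class above.
    from transformers import Mask2FormerImageProcessor

    try:
        Mask2FormerImageProcessor()
    except ImportError as err:
        print(err)  # requires_backends points at the missing vision dependency
```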
This diff is collapsed.
@@ -34,7 +34,7 @@ if is_torch_available():
    from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerModel

if is_vision_available():
-    from transformers import MaskFormerImageProcessor
+    from transformers import Mask2FormerImageProcessor

if is_vision_available():
    from PIL import Image

@@ -325,7 +325,7 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_feature_extractor(self):
-        return MaskFormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
+        return Mask2FormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None

    def test_inference_no_head(self):
        model = Mask2FormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
...
@@ -576,6 +576,34 @@ class MaskFormerImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes
        self.assertEqual(segmentation[0].shape, target_sizes[0])

+    def test_post_process_instance_segmentation(self):
+        feature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
+        outputs = self.image_processor_tester.get_fake_maskformer_outputs()
+        segmentation = feature_extractor.post_process_instance_segmentation(outputs, threshold=0)
+
+        self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
+        for el in segmentation:
+            self.assertTrue("segmentation" in el)
+            self.assertTrue("segments_info" in el)
+            self.assertEqual(type(el["segments_info"]), list)
+            self.assertEqual(
+                el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width)
+            )
+
+        segmentation = feature_extractor.post_process_instance_segmentation(
+            outputs, threshold=0, return_binary_maps=True
+        )
+        self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
+        for el in segmentation:
+            self.assertTrue("segmentation" in el)
+            self.assertTrue("segments_info" in el)
+            self.assertEqual(type(el["segments_info"]), list)
+            self.assertEqual(len(el["segmentation"].shape), 3)
+            self.assertEqual(
+                el["segmentation"].shape[1:], (self.image_processor_tester.height, self.image_processor_tester.width)
+            )
+
    def test_post_process_panoptic_segmentation(self):
        image_processing = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
        outputs = self.image_processor_tester.get_fake_maskformer_outputs()
...