Unverified Commit f424b094 authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Fix MaskFormerImageProcessor.post_process_instance_segmentation (#21256)

* fix instance segmentation post processing

* add Mask2FormerImageProcessor
parent 767939af
......@@ -22,8 +22,8 @@ The abstract from the paper is the following:
of semantics defines a task. While only the semantics of each task differ, current research focuses on designing specialized architectures for each task. We present Masked-attention Mask Transformer (Mask2Former), a new architecture capable of addressing any image segmentation task (panoptic, instance or semantic). Its key components include masked attention, which extracts localized features by constraining cross-attention within predicted mask regions. In addition to reducing the research effort by at least three times, it outperforms the best specialized architectures by a significant margin on four popular datasets. Most notably, Mask2Former sets a new state-of-the-art for panoptic segmentation (57.8 PQ on COCO), instance segmentation (50.1 AP on COCO) and semantic segmentation (57.7 mIoU on ADE20K).*
Tips:
- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`MaskFormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model.
- To get the final segmentation, depending on the task, you can call [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`]. All three tasks can be solved using [`Mask2FormerForUniversalSegmentation`] output, panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object/s (e.g. sky) together.
- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`Mask2FormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model.
- To get the final segmentation, depending on the task, you can call [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or [`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or [`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`]. All three tasks can be solved using [`Mask2FormerForUniversalSegmentation`] output, panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object/s (e.g. sky) together.
This model was contributed by [Shivalika Singh](https://huggingface.co/shivi) and [Alara Dirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/facebookresearch/Mask2Former).
......@@ -55,3 +55,12 @@ The resource should ideally demonstrate something new instead of duplicating an
[[autodoc]] Mask2FormerForUniversalSegmentation
- forward
## Mask2FormerImageProcessor
[[autodoc]] Mask2FormerImageProcessor
- preprocess
- encode_inputs
- post_process_semantic_segmentation
- post_process_instance_segmentation
- post_process_panoptic_segmentation
\ No newline at end of file
......@@ -799,6 +799,7 @@ else:
_import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
_import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
_import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
_import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
_import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
_import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
_import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
......@@ -4152,6 +4153,7 @@ if TYPE_CHECKING:
from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2ImageProcessor
from .models.layoutlmv3 import LayoutLMv3FeatureExtractor, LayoutLMv3ImageProcessor
from .models.levit import LevitFeatureExtractor, LevitImageProcessor
from .models.mask2former import Mask2FormerImageProcessor
from .models.maskformer import MaskFormerFeatureExtractor, MaskFormerImageProcessor
from .models.mobilenet_v1 import MobileNetV1FeatureExtractor, MobileNetV1ImageProcessor
from .models.mobilenet_v2 import MobileNetV2FeatureExtractor, MobileNetV2ImageProcessor
......
......@@ -62,7 +62,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2ImageProcessor"),
("layoutlmv3", "LayoutLMv3ImageProcessor"),
("levit", "LevitImageProcessor"),
("mask2former", "MaskFormerImageProcessor"),
("mask2former", "Mask2FormerImageProcessor"),
("maskformer", "MaskFormerImageProcessor"),
("mobilenet_v1", "MobileNetV1ImageProcessor"),
("mobilenet_v2", "MobileNetV2ImageProcessor"),
......
......@@ -27,6 +27,13 @@ _import_structure = {
],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_mask2former"] = ["Mask2FormerImageProcessor"]
try:
if not is_torch_available():
......@@ -44,6 +51,14 @@ else:
if TYPE_CHECKING:
from .configuration_mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_mask2former import Mask2FormerImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
......
......@@ -33,8 +33,8 @@ from huggingface_hub import hf_hub_download
from transformers import (
Mask2FormerConfig,
Mask2FormerForUniversalSegmentation,
Mask2FormerImageProcessor,
Mask2FormerModel,
MaskFormerImageProcessor,
SwinConfig,
)
from transformers.models.mask2former.modeling_mask2former import (
......@@ -193,11 +193,11 @@ class OriginalMask2FormerConfigToOursConverter:
class OriginalMask2FormerConfigToFeatureExtractorConverter:
def __call__(self, original_config: object) -> MaskFormerImageProcessor:
def __call__(self, original_config: object) -> Mask2FormerImageProcessor:
model = original_config.MODEL
model_input = original_config.INPUT
return MaskFormerImageProcessor(
return Mask2FormerImageProcessor(
image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
size=model_input.MIN_SIZE_TEST,
......@@ -847,7 +847,7 @@ class OriginalMask2FormerCheckpointToOursConverter:
def test(
original_model,
our_model: Mask2FormerForUniversalSegmentation,
feature_extractor: MaskFormerImageProcessor,
feature_extractor: Mask2FormerImageProcessor,
tolerance: float,
):
with torch.no_grad():
......
......@@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Mask2FormerConfig"
_CHECKPOINT_FOR_DOC = "facebook/mask2former-swin-small-coco-instance"
_IMAGE_PROCESSOR_FOR_DOC = "Mask2FormerImageProcessor"
MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/mask2former-swin-small-coco-instance",
......@@ -194,10 +195,10 @@ class Mask2FormerForUniversalSegmentationOutput(ModelOutput):
"""
Class for outputs of [`Mask2FormerForUniversalSegmentationOutput`].
This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or
[`~MaskFormerImageProcessor.post_process_instance_segmentation`] or
[`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
[`~MaskFormerImageProcessor] for details regarding usage.
This output can be directly passed to [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or
[`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or
[`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
[`~Mask2FormerImageProcessor] for details regarding usage.
Args:
loss (`torch.Tensor`, *optional*):
......
......@@ -1016,6 +1016,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
overlap_mask_area_threshold: float = 0.8,
target_sizes: Optional[List[Tuple[int, int]]] = None,
return_coco_annotation: Optional[bool] = False,
return_binary_maps: Optional[bool] = False,
) -> List[Dict]:
"""
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into instance segmentation predictions. Only
......@@ -1034,9 +1035,11 @@ class MaskFormerImageProcessor(BaseImageProcessor):
target_sizes (`List[Tuple]`, *optional*):
List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
final size (height, width) of each prediction. If left to None, predictions will not be resized.
return_coco_annotation (`bool`, *optional*):
Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
format.
return_coco_annotation (`bool`, *optional*, defaults to `False`):
If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
return_binary_maps (`bool`, *optional*, defaults to `False`):
If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps
(one per detected instance).
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
......@@ -1047,47 +1050,73 @@ class MaskFormerImageProcessor(BaseImageProcessor):
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
- **score** -- Prediction score of segment with `segment_id`.
"""
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
batch_size = class_queries_logits.shape[0]
num_labels = class_queries_logits.shape[-1] - 1
if return_coco_annotation and return_binary_maps:
raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.")
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
# [batch_size, num_queries, num_classes+1]
class_queries_logits = outputs.class_queries_logits
# [batch_size, num_queries, height, width]
masks_queries_logits = outputs.masks_queries_logits
# Predicted label and score of each query (batch_size, num_queries)
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
device = masks_queries_logits.device
num_classes = class_queries_logits.shape[-1] - 1
num_queries = class_queries_logits.shape[-2]
# Loop over items in batch size
results: List[Dict[str, TensorType]] = []
for i in range(batch_size):
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
)
for i in range(class_queries_logits.shape[0]):
mask_pred = masks_queries_logits[i]
mask_cls = class_queries_logits[i]
# No mask found
if mask_probs_item.shape[0] <= 0:
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1]
labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
# Get segmentation map and segment information of batch item
target_size = target_sizes[i] if target_sizes is not None else None
segmentation, segments = compute_segments(
mask_probs=mask_probs_item,
pred_scores=pred_scores_item,
pred_labels=pred_labels_item,
mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold,
label_ids_to_fuse=[],
target_size=target_size,
)
scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
labels_per_image = labels[topk_indices]
# Return segmentation map in run-length encoding (RLE) format
if return_coco_annotation:
segmentation = convert_segmentation_to_rle(segmentation)
topk_indices = topk_indices // num_classes
mask_pred = mask_pred[topk_indices]
pred_masks = (mask_pred > 0).float()
# Calculate average mask prob
mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
pred_masks.flatten(1).sum(1) + 1e-6
)
pred_scores = scores_per_image * mask_scores_per_image
pred_classes = labels_per_image
segmentation = torch.zeros(masks_queries_logits.shape[2:]) - 1
if target_sizes is not None:
segmentation = torch.zeros(target_sizes[i]) - 1
pred_masks = torch.nn.functional.interpolate(
pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest"
)[0]
instance_maps, segments = [], []
current_segment_id = 0
for j in range(num_queries):
score = pred_scores[j].item()
if not torch.all(pred_masks[j] == 0) and score >= threshold:
segmentation[pred_masks[j] == 1] = current_segment_id
segments.append(
{
"id": current_segment_id,
"label_id": pred_classes[j].item(),
"was_fused": False,
"score": round(score, 6),
}
)
current_segment_id += 1
instance_maps.append(pred_masks[j])
# Return segmentation map in run-length encoding (RLE) format
if return_coco_annotation:
segmentation = convert_segmentation_to_rle(segmentation)
# Return a concatenated tensor of binary instance maps
if return_binary_maps and len(instance_maps) != 0:
segmentation = torch.stack(instance_maps, dim=0)
results.append({"segmentation": segmentation, "segments_info": segments})
return results
......
......@@ -269,6 +269,13 @@ class LevitImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class Mask2FormerImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class MaskFormerFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
......
This diff is collapsed.
......@@ -34,7 +34,7 @@ if is_torch_available():
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerModel
if is_vision_available():
from transformers import MaskFormerImageProcessor
from transformers import Mask2FormerImageProcessor
if is_vision_available():
from PIL import Image
......@@ -325,7 +325,7 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return MaskFormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
return Mask2FormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
def test_inference_no_head(self):
model = Mask2FormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
......
......@@ -576,6 +576,34 @@ class MaskFormerImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes
self.assertEqual(segmentation[0].shape, target_sizes[0])
def test_post_process_instance_segmentation(self):
feature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
outputs = self.image_processor_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_instance_segmentation(outputs, threshold=0)
self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(
el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width)
)
segmentation = feature_extractor.post_process_instance_segmentation(
outputs, threshold=0, return_binary_maps=True
)
self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(len(el["segmentation"].shape), 3)
self.assertEqual(
el["segmentation"].shape[1:], (self.image_processor_tester.height, self.image_processor_tester.width)
)
def test_post_process_panoptic_segmentation(self):
image_processing = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
outputs = self.image_processor_tester.get_fake_maskformer_outputs()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment