Unverified Commit 07e94bf1 authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Maskformer post-processing fixes and improvements (#19172)

- Improves MaskFormer docs, corrects minor typos
- Restructures MaskFormerFeatureExtractor.post_process_panoptic_segmentation for better readability, adds target_sizes argument for optional resizing
- Adds post_process_semantic_segmentation and post_process_instance_segmentation methods.
- Adds a deprecation warning to post_process_segmentation method in favour of post_process_instance_segmentation
parent 6268694e
......@@ -58,6 +58,7 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The
- encode_inputs
- post_process_segmentation
- post_process_semantic_segmentation
- post_process_instance_segmentation
- post_process_panoptic_segmentation
## MaskFormerModel
......
......@@ -259,7 +259,8 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
"""
Class for outputs of [`MaskFormerForInstanceSegmentation`].
This output can be directly passed to [`~MaskFormerFeatureExtractor.post_process_segmentation`] or
This output can be directly passed to [`~MaskFormerFeatureExtractor.post_process_semantic_segmentation`] or
[`~MaskFormerFeatureExtractor.post_process_instance_segmentation`] or
[`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`] depending on the task. Please, see
[`~MaskFormerFeatureExtractor`] for details regarding usage.
......@@ -267,11 +268,11 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
loss (`torch.Tensor`, *optional*):
The computed loss, returned when labels are present.
class_queries_logits (`torch.FloatTensor`):
    A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
    query. Note the `+ 1` is needed because we incorporate the null class.
masks_queries_logits (`torch.FloatTensor`):
    A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
    query.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder model (backbone).
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
......@@ -2547,8 +2548,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
>>> masks_queries_logits = outputs.masks_queries_logits
>>> # you can pass them to feature_extractor for postprocessing
>>> output = feature_extractor.post_process_segmentation(outputs)
>>> output = feature_extractor.post_process_semantic_segmentation(outputs)
>>> output = feature_extractor.post_process_instance_segmentation(outputs)
>>> output = feature_extractor.post_process_panoptic_segmentation(outputs)
```
"""
......
......@@ -29,6 +29,7 @@ if is_torch_available():
if is_vision_available():
from transformers import MaskFormerFeatureExtractor
from transformers.models.maskformer.feature_extraction_maskformer import binary_mask_to_rle
from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput
if is_vision_available():
......@@ -344,6 +345,17 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
common(is_instance_map=False, segmentation_type="pil")
common(is_instance_map=True, segmentation_type="pil")
def test_binary_mask_to_rle(self):
    """Round-trip a small hand-built binary mask through `binary_mask_to_rle` and pin the encoding."""
    # 20x50 mask with three horizontal runs of ones; everything else is zero.
    mask = np.zeros((20, 50))
    mask[0, 20:] = 1
    mask[1, :15] = 1
    mask[5, :10] = 1

    encoding = binary_mask_to_rle(mask)

    # Exactly four alternating run lengths are expected for this mask; the
    # first two are pinned to the values the encoder produces for it.
    self.assertEqual(len(encoding), 4)
    for position, expected_run in ((0, 21), (1, 45)):
        self.assertEqual(encoding[position], expected_run)
def test_post_process_segmentation(self):
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
......@@ -373,31 +385,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
self.assertEqual(len(segmentation), self.feature_extract_tester.batch_size)
self.assertEqual(
segmentation.shape,
segmentation[0].shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.height,
self.feature_extract_tester.width,
),
)
target_size = (1, 4)
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_size=target_size)
target_sizes = [(1, 4) for i in range(self.feature_extract_tester.batch_size)]
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))
self.assertEqual(segmentation[0].shape, target_sizes[0])
def test_post_process_panoptic_segmentation(self):
    """Smoke-test `post_process_panoptic_segmentation` on fake model outputs.

    Verifies the method returns one result dict per batch element, each with a
    `segmentation` map matching the input spatial size and a `segments_info` list.
    """
    feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
    outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
    # threshold=0 presumably disables mask filtering so the fake outputs always
    # yield segments (replaces the old `object_mask_threshold` kwarg).
    segmentation = feature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)

    self.assertEqual(len(segmentation), self.feature_extract_tester.batch_size)
    for el in segmentation:
        self.assertIn("segmentation", el)
        # The post-change API exposes `segments_info` (not the old `segments` key).
        self.assertIn("segments_info", el)
        self.assertIsInstance(el["segments_info"], list)
        self.assertEqual(
            el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment