Unverified Commit 07e94bf1 authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Maskformer post-processing fixes and improvements (#19172)

- Improves MaskFormer docs, corrects minor typos
- Restructures MaskFormerFeatureExtractor.post_process_panoptic_segmentation for better readability, adds target_sizes argument for optional resizing
- Adds post_process_semantic_segmentation and post_process_instance_segmentation methods.
- Adds a deprecation warning to post_process_segmentation method in favour of post_process_instance_segmentation
parent 6268694e
...@@ -58,6 +58,7 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The ...@@ -58,6 +58,7 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The
- encode_inputs - encode_inputs
- post_process_segmentation - post_process_segmentation
- post_process_semantic_segmentation - post_process_semantic_segmentation
- post_process_instance_segmentation
- post_process_panoptic_segmentation - post_process_panoptic_segmentation
## MaskFormerModel ## MaskFormerModel
......
...@@ -259,7 +259,8 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput): ...@@ -259,7 +259,8 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
""" """
Class for outputs of [`MaskFormerForInstanceSegmentation`]. Class for outputs of [`MaskFormerForInstanceSegmentation`].
This output can be directly passed to [`~MaskFormerFeatureExtractor.post_process_segmentation`] or This output can be directly passed to [`~MaskFormerFeatureExtractor.post_process_semantic_segmentation`] or
[`~MaskFormerFeatureExtractor.post_process_instance_segmentation`] or
[`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`] depending on the task. Please, see [`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`] depending on the task. Please, see
[`~MaskFormerFeatureExtractor`] for details regarding usage. [`~MaskFormerFeatureExtractor`] for details regarding usage.
...@@ -267,11 +268,11 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput): ...@@ -267,11 +268,11 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
loss (`torch.Tensor`, *optional*): loss (`torch.Tensor`, *optional*):
The computed loss, returned when labels are present. The computed loss, returned when labels are present.
class_queries_logits (`torch.FloatTensor`): class_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
query.
masks_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
query. Note the `+ 1` is needed because we incorporate the null class. query. Note the `+ 1` is needed because we incorporate the null class.
masks_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
query.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder model (backbone). Last hidden states (final feature map) of the last stage of the encoder model (backbone).
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
...@@ -2547,8 +2548,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel): ...@@ -2547,8 +2548,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
>>> masks_queries_logits = outputs.masks_queries_logits >>> masks_queries_logits = outputs.masks_queries_logits
>>> # you can pass them to feature_extractor for postprocessing >>> # you can pass them to feature_extractor for postprocessing
>>> output = feature_extractor.post_process_segmentation(outputs)
>>> output = feature_extractor.post_process_semantic_segmentation(outputs) >>> output = feature_extractor.post_process_semantic_segmentation(outputs)
>>> output = feature_extractor.post_process_instance_segmentation(outputs)
>>> output = feature_extractor.post_process_panoptic_segmentation(outputs) >>> output = feature_extractor.post_process_panoptic_segmentation(outputs)
``` ```
""" """
......
...@@ -29,6 +29,7 @@ if is_torch_available(): ...@@ -29,6 +29,7 @@ if is_torch_available():
if is_vision_available(): if is_vision_available():
from transformers import MaskFormerFeatureExtractor from transformers import MaskFormerFeatureExtractor
from transformers.models.maskformer.feature_extraction_maskformer import binary_mask_to_rle
from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput
if is_vision_available(): if is_vision_available():
...@@ -344,6 +345,17 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest ...@@ -344,6 +345,17 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
common(is_instance_map=False, segmentation_type="pil") common(is_instance_map=False, segmentation_type="pil")
common(is_instance_map=True, segmentation_type="pil") common(is_instance_map=True, segmentation_type="pil")
def test_binary_mask_to_rle(self):
    """Encode a small hand-built binary mask and check the leading RLE entries."""
    # Start from an all-zero 20x50 mask and switch on three horizontal strips:
    # row 0 from column 20 onward, row 1 up to column 15, row 5 up to column 10.
    mask = np.zeros((20, 50))
    strips = (
        (0, slice(20, None)),
        (1, slice(None, 15)),
        (5, slice(None, 10)),
    )
    for row, columns in strips:
        mask[row, columns] = 1

    encoded = binary_mask_to_rle(mask)

    # The encoding is expected to hold four entries; only the first two values are pinned.
    self.assertEqual(len(encoded), 4)
    for position, expected in enumerate((21, 45)):
        self.assertEqual(encoded[position], expected)
def test_post_process_segmentation(self): def test_post_process_segmentation(self):
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs() outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
...@@ -373,31 +385,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest ...@@ -373,31 +385,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
segmentation = fature_extractor.post_process_semantic_segmentation(outputs) segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
self.assertEqual(len(segmentation), self.feature_extract_tester.batch_size)
self.assertEqual( self.assertEqual(
segmentation.shape, segmentation[0].shape,
( (
self.feature_extract_tester.batch_size,
self.feature_extract_tester.height, self.feature_extract_tester.height,
self.feature_extract_tester.width, self.feature_extract_tester.width,
), ),
) )
target_size = (1, 4) target_sizes = [(1, 4) for i in range(self.feature_extract_tester.batch_size)]
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_size=target_size)
self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size)) self.assertEqual(segmentation[0].shape, target_sizes[0])
def test_post_process_panoptic_segmentation(self): def test_post_process_panoptic_segmentation(self):
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs() outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0) segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size) self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
for el in segmentation: for el in segmentation:
self.assertTrue("segmentation" in el) self.assertTrue("segmentation" in el)
self.assertTrue("segments" in el) self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments"]), list) self.assertEqual(type(el["segments_info"]), list)
self.assertEqual( self.assertEqual(
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width) el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment