Unverified Commit 7598791c authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Fix MaskFormer failing postprocess tests (#19354)

Ensures post_process_instance_segmentation and post_process_panoptic_segmentation methods return a tensor of shape (target_height, target_width) filled with -1 values if no segment with score > threshold is found.
parent ad98642a
...@@ -772,8 +772,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -772,8 +772,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# No mask found # No mask found
if mask_probs_item.shape[0] <= 0: if mask_probs_item.shape[0] <= 0:
segmentation = None height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segments: List[Dict] = [] segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue continue
# Get segmentation map and segment information of batch item # Get segmentation map and segment information of batch item
...@@ -860,8 +861,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -860,8 +861,9 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# No mask found # No mask found
if mask_probs_item.shape[0] <= 0: if mask_probs_item.shape[0] <= 0:
segmentation = None height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segments: List[Dict] = [] segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue continue
# Get segmentation map and segment information of batch item # Get segmentation map and segment information of batch item
......
...@@ -401,10 +401,11 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest ...@@ -401,10 +401,11 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
@unittest.skip("Fix me Alara!") @unittest.skip("Fix me Alara!")
def test_post_process_panoptic_segmentation(self): def test_post_process_panoptic_segmentation(self):
fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes) feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs() outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, threshold=0) segmentation = feature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
print(len(segmentation))
print(self.feature_extract_tester.batch_size)
self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size) self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
for el in segmentation: for el in segmentation:
self.assertTrue("segmentation" in el) self.assertTrue("segmentation" in el)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment