Unverified Commit 52c9e6af authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Fix bug in segmentation postprocessing (#20198)

* Fix post_process_instance_segmentation
* Add test for label fusing
parent 292acd71
...@@ -1050,12 +1050,13 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -1050,12 +1050,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
# Get segmentation map and segment information of batch item # Get segmentation map and segment information of batch item
target_size = target_sizes[i] if target_sizes is not None else None target_size = target_sizes[i] if target_sizes is not None else None
segmentation, segments = compute_segments( segmentation, segments = compute_segments(
mask_probs_item, mask_probs=mask_probs_item,
pred_scores_item, pred_scores=pred_scores_item,
pred_labels_item, pred_labels=pred_labels_item,
mask_threshold, mask_threshold=mask_threshold,
overlap_mask_area_threshold, overlap_mask_area_threshold=overlap_mask_area_threshold,
target_size, label_ids_to_fuse=[],
target_size=target_size,
) )
# Return segmentation map in run-length encoding (RLE) format # Return segmentation map in run-length encoding (RLE) format
...@@ -1143,13 +1144,13 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -1143,13 +1144,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
# Get segmentation map and segment information of batch item # Get segmentation map and segment information of batch item
target_size = target_sizes[i] if target_sizes is not None else None target_size = target_sizes[i] if target_sizes is not None else None
segmentation, segments = compute_segments( segmentation, segments = compute_segments(
mask_probs_item, mask_probs=mask_probs_item,
pred_scores_item, pred_scores=pred_scores_item,
pred_labels_item, pred_labels=pred_labels_item,
mask_threshold, mask_threshold=mask_threshold,
overlap_mask_area_threshold, overlap_mask_area_threshold=overlap_mask_area_threshold,
label_ids_to_fuse, label_ids_to_fuse=label_ids_to_fuse,
target_size, target_size=target_size,
) )
results.append({"segmentation": segmentation, "segments_info": segments}) results.append({"segmentation": segmentation, "segments_info": segments})
......
...@@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest ...@@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertEqual( self.assertEqual(
el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width) el["segmentation"].shape, (self.feature_extract_tester.height, self.feature_extract_tester.width)
) )
def test_post_process_label_fusing(self):
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_panoptic_segmentation(
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0
)
unfused_segments = [el["segments_info"] for el in segmentation]
fused_segmentation = feature_extractor.post_process_panoptic_segmentation(
outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1}
)
fused_segments = [el["segments_info"] for el in fused_segmentation]
for el_unfused, el_fused in zip(unfused_segments, fused_segments):
if len(el_unfused) == 0:
self.assertEqual(len(el_unfused), len(el_fused))
continue
# Get number of segments to be fused
fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}]
num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1
# Expected number of segments after fusing
expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
num_segments_fused = max([el["id"] for el in el_fused])
self.assertEqual(num_segments_fused, expected_num_segments)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment