Unverified Commit 8ef98628 authored by Nick DeGroot's avatar Nick DeGroot Committed by GitHub
Browse files

Fix OneFormer `post_process_instance_segmentation` for panoptic tasks (#29304)

* 🐛 Fix oneformer instance post processing when using panoptic task type

* 

 Add unit test for oneformer instance post processing panoptic bug

---------
Co-authored-by: default avatarNick DeGroot <1966472+nickthegroot@users.noreply.github.com>
parent 81220cba
......@@ -1244,8 +1244,8 @@ class OneFormerImageProcessor(BaseImageProcessor):
# if this is panoptic segmentation, we only keep the "thing" classes
if task_type == "panoptic":
keep = torch.zeros_like(scores_per_image).bool()
for i, lab in enumerate(labels_per_image):
keep[i] = lab in self.metadata["thing_ids"]
for j, lab in enumerate(labels_per_image):
keep[j] = lab in self.metadata["thing_ids"]
scores_per_image = scores_per_image[keep]
labels_per_image = labels_per_image[keep]
......@@ -1258,8 +1258,8 @@ class OneFormerImageProcessor(BaseImageProcessor):
continue
if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type:
for i in range(labels_per_image.shape[0]):
labels_per_image[i] = self.metadata["thing_ids"].index(labels_per_image[i].item())
for j in range(labels_per_image.shape[0]):
labels_per_image[j] = self.metadata["thing_ids"].index(labels_per_image[j].item())
# Get segmentation map and segment information of batch item
target_size = target_sizes[i] if target_sizes is not None else None
......
......@@ -295,6 +295,19 @@ class OneFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width)
)
segmentation_with_opts = image_processor.post_process_instance_segmentation(
outputs,
threshold=0,
target_sizes=[(1, 4) for _ in range(self.image_processor_tester.batch_size)],
task_type="panoptic",
)
self.assertTrue(len(segmentation_with_opts) == self.image_processor_tester.batch_size)
for el in segmentation_with_opts:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(el["segmentation"].shape, (1, 4))
def test_post_process_panoptic_segmentation(self):
image_processor = self.image_processing_class(
num_labels=self.image_processor_tester.num_classes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment