Unverified Commit 742273a5 authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

fix for the output from post_process_panoptic_segmentation (#15916)

parent 7c45fe74
......@@ -538,7 +538,6 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# create the area, since bool we just need to sum :)
mask_k_area = mask_k.sum()
# this is the area of all the stuff in query k
# TODO not 100%, why are the taking the k query here????
original_area = (mask_probs[k] >= 0.5).sum()
mask_does_exist = mask_k_area > 0 and original_area > 0
......@@ -565,5 +564,5 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
)
if is_stuff:
stuff_memory_list[pred_class] = current_segment_id
results.append({"segmentation": segmentation, "segments": segments})
results.append({"segmentation": segmentation, "segments": segments})
return results
......@@ -404,3 +404,23 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs)
self.assertTrue(outputs.loss is not None)
def test_panoptic_segmentation(self):
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
feature_extractor = self.default_feature_extractor
inputs = feature_extractor(
[np.zeros((3, 384, 384)), np.zeros((3, 384, 384))],
annotations=[
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
],
return_tensors="pt",
)
with torch.no_grad():
outputs = model(**inputs)
panoptic_segmentation = feature_extractor.post_process_panoptic_segmentation(outputs)
self.assertTrue(len(panoptic_segmentation) == 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment