Unverified Commit 742273a5 authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

fix for the output from post_process_panoptic_segmentation (#15916)

parent 7c45fe74
...@@ -538,7 +538,6 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -538,7 +538,6 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
# create the area, since bool we just need to sum :) # create the area, since bool we just need to sum :)
mask_k_area = mask_k.sum() mask_k_area = mask_k.sum()
# this is the area of all the stuff in query k # this is the area of all the stuff in query k
# TODO not 100%, why are the taking the k query here????
original_area = (mask_probs[k] >= 0.5).sum() original_area = (mask_probs[k] >= 0.5).sum()
mask_does_exist = mask_k_area > 0 and original_area > 0 mask_does_exist = mask_k_area > 0 and original_area > 0
......
...@@ -404,3 +404,23 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): ...@@ -404,3 +404,23 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
self.assertTrue(outputs.loss is not None) self.assertTrue(outputs.loss is not None)
def test_panoptic_segmentation(self):
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
feature_extractor = self.default_feature_extractor
inputs = feature_extractor(
[np.zeros((3, 384, 384)), np.zeros((3, 384, 384))],
annotations=[
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
],
return_tensors="pt",
)
with torch.no_grad():
outputs = model(**inputs)
panoptic_segmentation = feature_extractor.post_process_panoptic_segmentation(outputs)
self.assertTrue(len(panoptic_segmentation) == 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment