Unverified Commit 3822e4a5 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Enabling MaskFormer in pipelines (#15917)

* Enabling MaskFormer in ppipelines

No AutoModel though :(

* Ooops local file.
parent 79d28e80
...@@ -565,5 +565,5 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM ...@@ -565,5 +565,5 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
) )
if is_stuff: if is_stuff:
stuff_memory_list[pred_class] = current_segment_id stuff_memory_list[pred_class] = current_segment_id
results.append({"segmentation": segmentation, "segments": segments}) results.append({"segmentation": segmentation, "segments": segments})
return results return results
...@@ -110,7 +110,18 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -110,7 +110,18 @@ class ImageSegmentationPipeline(Pipeline):
return model_outputs return model_outputs
def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5): def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5):
if hasattr(self.feature_extractor, "post_process_segmentation"): if hasattr(self.feature_extractor, "post_process_panoptic_segmentation"):
outputs = self.feature_extractor.post_process_panoptic_segmentation(
model_outputs, is_thing_map=self.model.config.id2label
)[0]
annotation = []
segmentation = outputs["segmentation"]
for segment in outputs["segments"]:
mask = (segmentation == segment["id"]) * 255
mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
label = self.model.config.id2label[segment["category_id"]]
annotation.append({"mask": mask, "label": label, "score": None})
elif hasattr(self.feature_extractor, "post_process_segmentation"):
# Panoptic # Panoptic
raw_annotations = self.feature_extractor.post_process_segmentation( raw_annotations = self.feature_extractor.post_process_segmentation(
model_outputs, model_outputs["target_size"], threshold=threshold, mask_threshold=0.5 model_outputs, model_outputs["target_size"], threshold=threshold, mask_threshold=0.5
......
...@@ -16,6 +16,7 @@ import hashlib ...@@ -16,6 +16,7 @@ import hashlib
import unittest import unittest
import datasets import datasets
from datasets import load_dataset
from transformers import ( from transformers import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
...@@ -308,3 +309,35 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -308,3 +309,35 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
{"score": 0.9994, "label": "cat", "mask": "88b37bd2202c750cc9dd191518050a9b0ca5228c"}, {"score": 0.9994, "label": "cat", "mask": "88b37bd2202c750cc9dd191518050a9b0ca5228c"},
], ],
) )
@require_torch
@slow
def test_maskformer(self):
threshold = 0.999
model_id = "facebook/maskformer-swin-base-ade"
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
model = MaskFormerForInstanceSegmentation.from_pretrained(model_id)
feature_extractor = MaskFormerFeatureExtractor.from_pretrained(model_id)
image_segmenter = pipeline("image-segmentation", model=model, feature_extractor=feature_extractor)
image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
outputs = image_segmenter(image[0]["file"], threshold=threshold)
for o in outputs:
o["mask"] = hashimage(o["mask"])
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"mask": "20d1b9480d1dc1501dbdcfdff483e370", "label": "wall", "score": None},
{"mask": "0f902fbc66a0ff711ea455b0e4943adf", "label": "house", "score": None},
{"mask": "4537bdc07d47d84b3f8634b7ada37bd4", "label": "grass", "score": None},
{"mask": "b7ac77dfae44a904b479a0926a2acaf7", "label": "tree", "score": None},
{"mask": "e9bedd56bd40650fb263ce03eb621079", "label": "plant", "score": None},
{"mask": "37a609f8c9c1b8db91fbff269f428b20", "label": "road, route", "score": None},
{"mask": "0d8cdfd63bae8bf6e4344d460a2fa711", "label": "sky", "score": None},
],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment