Unverified Commit f4e4ad34 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Add `ForInstanceSegmentation` models to `image-segmentation` pipelines (#15937)

* Adding ForInstanceSegmentation to pipelines.

* Last fix `category_id` renamed to `label_id`.

* Can't be none no more.

* No `is_thing_map` anymore.
parent 5b7dcc73
...@@ -18,6 +18,7 @@ if is_torch_available(): ...@@ -18,6 +18,7 @@ if is_torch_available():
from ..models.auto.modeling_auto import ( from ..models.auto.modeling_auto import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
) )
...@@ -32,10 +33,10 @@ Predictions = List[Prediction] ...@@ -32,10 +33,10 @@ Predictions = List[Prediction]
@add_end_docstrings(PIPELINE_INIT_ARGS) @add_end_docstrings(PIPELINE_INIT_ARGS)
class ImageSegmentationPipeline(Pipeline): class ImageSegmentationPipeline(Pipeline):
""" """
Image segmentation pipeline using any `AutoModelForImageSegmentation`. This pipeline predicts masks of objects and Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
their classes. their classes.
This image segmntation pipeline can currently be loaded from [`pipeline`] using the following task identifier: This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"image-segmentation"`. `"image-segmentation"`.
See the list of available models on See the list of available models on
...@@ -50,7 +51,11 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -50,7 +51,11 @@ class ImageSegmentationPipeline(Pipeline):
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type( self.check_model_type(
dict(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items() + MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items()) dict(
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items()
)
) )
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, **kwargs):
...@@ -112,14 +117,14 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -112,14 +117,14 @@ class ImageSegmentationPipeline(Pipeline):
def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5): def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5):
if hasattr(self.feature_extractor, "post_process_panoptic_segmentation"): if hasattr(self.feature_extractor, "post_process_panoptic_segmentation"):
outputs = self.feature_extractor.post_process_panoptic_segmentation( outputs = self.feature_extractor.post_process_panoptic_segmentation(
model_outputs, is_thing_map=self.model.config.id2label model_outputs, object_mask_threshold=threshold
)[0] )[0]
annotation = [] annotation = []
segmentation = outputs["segmentation"] segmentation = outputs["segmentation"]
for segment in outputs["segments"]: for segment in outputs["segments"]:
mask = (segmentation == segment["id"]) * 255 mask = (segmentation == segment["id"]) * 255
mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L") mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
label = self.model.config.id2label[segment["category_id"]] label = self.model.config.id2label[segment["label_id"]]
annotation.append({"mask": mask, "label": label, "score": None}) annotation.append({"mask": mask, "label": label, "score": None})
elif hasattr(self.feature_extractor, "post_process_segmentation"): elif hasattr(self.feature_extractor, "post_process_segmentation"):
# Panoptic # Panoptic
......
...@@ -20,11 +20,14 @@ from datasets import load_dataset ...@@ -20,11 +20,14 @@ from datasets import load_dataset
from transformers import ( from transformers import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
AutoFeatureExtractor, AutoFeatureExtractor,
AutoModelForImageSegmentation, AutoModelForImageSegmentation,
AutoModelForInstanceSegmentation,
DetrForSegmentation, DetrForSegmentation,
ImageSegmentationPipeline, ImageSegmentationPipeline,
MaskFormerForInstanceSegmentation,
is_vision_available, is_vision_available,
pipeline, pipeline,
) )
...@@ -67,6 +70,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -67,6 +70,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
list(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()) if MODEL_FOR_IMAGE_SEGMENTATION_MAPPING else [] list(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()) if MODEL_FOR_IMAGE_SEGMENTATION_MAPPING else []
) )
+ (MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items() if MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING else []) + (MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items() if MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING else [])
+ (MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() if MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING else [])
} }
def get_test_pipeline(self, model, tokenizer, feature_extractor): def get_test_pipeline(self, model, tokenizer, feature_extractor):
...@@ -80,7 +84,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -80,7 +84,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0) outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
self.assertIsInstance(outputs, list) self.assertIsInstance(outputs, list)
n = len(outputs) n = len(outputs)
self.assertGreater(n, 1) if isinstance(image_segmenter.model, (MaskFormerForInstanceSegmentation)):
# Instance segmentation (maskformer) have a slot for null class
# and can output nothing even with a low threshold
self.assertGreaterEqual(n, 0)
else:
self.assertGreaterEqual(n, 1)
# XXX: PIL.Image implements __eq__ which bypasses ANY, so we inverse the comparison # XXX: PIL.Image implements __eq__ which bypasses ANY, so we inverse the comparison
# to make it work # to make it work
self.assertEqual([{"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}] * n, outputs) self.assertEqual([{"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}] * n, outputs)
...@@ -119,7 +128,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -119,7 +128,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
] ]
outputs = image_segmenter(batch, threshold=0.0, batch_size=batch_size) outputs = image_segmenter(batch, threshold=0.0, batch_size=batch_size)
self.assertEqual(len(batch), len(outputs)) self.assertEqual(len(batch), len(outputs))
self.assertEqual({"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}, outputs[0][0])
self.assertEqual(len(outputs[0]), n) self.assertEqual(len(outputs[0]), n)
self.assertEqual( self.assertEqual(
[ [
...@@ -313,18 +321,17 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -313,18 +321,17 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
@require_torch @require_torch
@slow @slow
def test_maskformer(self): def test_maskformer(self):
threshold = 0.999 threshold = 0.8
model_id = "facebook/maskformer-swin-base-ade" model_id = "facebook/maskformer-swin-base-ade"
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation model = AutoModelForInstanceSegmentation.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = MaskFormerForInstanceSegmentation.from_pretrained(model_id)
feature_extractor = MaskFormerFeatureExtractor.from_pretrained(model_id)
image_segmenter = pipeline("image-segmentation", model=model, feature_extractor=feature_extractor) image_segmenter = pipeline("image-segmentation", model=model, feature_extractor=feature_extractor)
image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
outputs = image_segmenter(image[0]["file"], threshold=threshold) file = image[0]["file"]
outputs = image_segmenter(file, threshold=threshold)
for o in outputs: for o in outputs:
o["mask"] = hashimage(o["mask"]) o["mask"] = hashimage(o["mask"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment