Unverified Commit f1430377 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

Add `automatic-mask-generation` pipeline for Segment Anything Model (SAM) (#22840)



* cleanup

* updates

* more refactoring

* make style

* update inits

* support other inputs in base

* update based on review
Co-authored-by: default avatarNicolas Patry <patry.nicolas@gmail.com>

* Update tests/pipelines/test_pipelines_automatic_mask_generation.py
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>

* update

* fixup

* TODO x and y to refactor, _h _w refactored here

* update docstring

* more nits

* style on these

* more doc fix

* rename variables

* update

* updates

* style

* update

* fix `_mask_to_rle_pytorch`

* styling

* fix ask to rle, wrong outputs

* add device arg

* update

* more updates, fix tets

* udpate

* update docstrings

* styling

* fixup

* add notebook on the docs

* update orginal sizes

* fix docstring

* updat condition on point_per-batch

* updates tests

* fix CI  test

* extend is required, append does not work!

* fixup

* fix CI tests

* whit pixels left

* address doc comments

* fix doc

* slow pipeline tests

* update auto init

* add revision

* make fixup

* update p!ipoeline tag when calling tests

* alphabeitcal order in inits

* fix copies

* last style nits

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* reformat docstring

* more reformat

* address most of the comments

* Update src/transformers/pipelines/mask_generation.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* final refactor

* Update src/transformers/models/sam/image_processing_sam.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fixup and fix slow tests

* revert

---------
Co-authored-by: default avatarNicolas Patry <patry.nicolas@gmail.com>
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: default avataryounesbelkada <younesbelkada@gmail.com>
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent e5f34871
...@@ -64,6 +64,7 @@ scores = outputs.iou_scores ...@@ -64,6 +64,7 @@ scores = outputs.iou_scores
Resources: Resources:
- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb) for using the model - [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb) for using the model
- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/automatic_mask_generation.ipynb) for using automatic mask generation pipeline.
## SamConfig ## SamConfig
......
...@@ -1012,6 +1012,7 @@ else: ...@@ -1012,6 +1012,7 @@ else:
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MASK_GENERATION_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"MODEL_FOR_OBJECT_DETECTION_MAPPING", "MODEL_FOR_OBJECT_DETECTION_MAPPING",
...@@ -4650,6 +4651,7 @@ if TYPE_CHECKING: ...@@ -4650,6 +4651,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
......
...@@ -52,6 +52,7 @@ else: ...@@ -52,6 +52,7 @@ else:
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MASK_GENERATION_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"MODEL_FOR_OBJECT_DETECTION_MAPPING", "MODEL_FOR_OBJECT_DETECTION_MAPPING",
...@@ -213,6 +214,7 @@ if TYPE_CHECKING: ...@@ -213,6 +214,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
......
...@@ -977,7 +977,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( ...@@ -977,7 +977,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
] ]
) )
MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES = OrderedDict( MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[ [
("sam", "SamModel"), ("sam", "SamModel"),
] ]
...@@ -1058,9 +1058,11 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_F ...@@ -1058,9 +1058,11 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_F
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES) MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING = _LazyAutoMapping( MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
CONFIG_MAPPING_NAMES, MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES
)
class AutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
class AutoModel(_BaseAutoModelClass): class AutoModel(_BaseAutoModelClass):
......
...@@ -63,6 +63,7 @@ from .fill_mask import FillMaskPipeline ...@@ -63,6 +63,7 @@ from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline from .image_classification import ImageClassificationPipeline
from .image_segmentation import ImageSegmentationPipeline from .image_segmentation import ImageSegmentationPipeline
from .image_to_text import ImageToTextPipeline from .image_to_text import ImageToTextPipeline
from .mask_generation import MaskGenerationPipeline
from .object_detection import ObjectDetectionPipeline from .object_detection import ObjectDetectionPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
...@@ -124,6 +125,7 @@ if is_torch_available(): ...@@ -124,6 +125,7 @@ if is_torch_available():
AutoModelForImageClassification, AutoModelForImageClassification,
AutoModelForImageSegmentation, AutoModelForImageSegmentation,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForMaskGeneration,
AutoModelForObjectDetection, AutoModelForObjectDetection,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
AutoModelForSemanticSegmentation, AutoModelForSemanticSegmentation,
...@@ -384,6 +386,13 @@ SUPPORTED_TASKS = { ...@@ -384,6 +386,13 @@ SUPPORTED_TASKS = {
"default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}}, "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
"type": "video", "type": "video",
}, },
"mask-generation": {
"impl": MaskGenerationPipeline,
"tf": (),
"pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
"default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
"type": "multimodal",
},
} }
NO_FEATURE_EXTRACTOR_TASKS = set() NO_FEATURE_EXTRACTOR_TASKS = set()
...@@ -536,6 +545,7 @@ def pipeline( ...@@ -536,6 +545,7 @@ def pipeline(
- `"image-classification"`: will return a [`ImageClassificationPipeline`]. - `"image-classification"`: will return a [`ImageClassificationPipeline`].
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`]. - `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
- `"image-to-text"`: will return a [`ImageToTextPipeline`]. - `"image-to-text"`: will return a [`ImageToTextPipeline`].
- `"mask-generation"`: will return a [`MaskGenerationPipeline`].
- `"object-detection"`: will return a [`ObjectDetectionPipeline`]. - `"object-detection"`: will return a [`ObjectDetectionPipeline`].
- `"question-answering"`: will return a [`QuestionAnsweringPipeline`]. - `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
- `"summarization"`: will return a [`SummarizationPipeline`]. - `"summarization"`: will return a [`SummarizationPipeline`].
......
...@@ -97,6 +97,8 @@ def _pad(items, key, padding_value, padding_side): ...@@ -97,6 +97,8 @@ def _pad(items, key, padding_value, padding_side):
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
elif dim == 3: elif dim == 3:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
elif dim == 4:
tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value
for i, item in enumerate(items): for i, item in enumerate(items):
if dim == 2: if dim == 2:
...@@ -109,6 +111,12 @@ def _pad(items, key, padding_value, padding_side): ...@@ -109,6 +111,12 @@ def _pad(items, key, padding_value, padding_side):
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone() tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
else: else:
tensor[i, : len(item[key][0]), :] = item[key][0].clone() tensor[i, : len(item[key][0]), :] = item[key][0].clone()
elif dim == 4:
if padding_side == "left":
tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()
return tensor return tensor
else: else:
return [item[key] for item in items] return [item[key] for item in items]
......
...@@ -81,11 +81,11 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -81,11 +81,11 @@ class ImageSegmentationPipeline(Pipeline):
) )
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, **kwargs):
preprocessor_kwargs = {} preprocess_kwargs = {}
postprocess_kwargs = {} postprocess_kwargs = {}
if "subtask" in kwargs: if "subtask" in kwargs:
postprocess_kwargs["subtask"] = kwargs["subtask"] postprocess_kwargs["subtask"] = kwargs["subtask"]
preprocessor_kwargs["subtask"] = kwargs["subtask"] preprocess_kwargs["subtask"] = kwargs["subtask"]
if "threshold" in kwargs: if "threshold" in kwargs:
postprocess_kwargs["threshold"] = kwargs["threshold"] postprocess_kwargs["threshold"] = kwargs["threshold"]
if "mask_threshold" in kwargs: if "mask_threshold" in kwargs:
...@@ -93,7 +93,7 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -93,7 +93,7 @@ class ImageSegmentationPipeline(Pipeline):
if "overlap_mask_area_threshold" in kwargs: if "overlap_mask_area_threshold" in kwargs:
postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"] postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
return preprocessor_kwargs, {}, postprocess_kwargs return preprocess_kwargs, {}, postprocess_kwargs
def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]: def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]:
""" """
......
from collections import defaultdict
from typing import Optional
from ..image_utils import load_image
from ..utils import (
add_end_docstrings,
is_torch_available,
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class MaskGenerationPipeline(ChunkPipeline):
"""
Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
image, given an image. It is a `ChunkPipeline` because you can seperate the points in a mini-batch in order to
avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the
same time. Default is `64`.
The pipeline works in 3 steps:
1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
labels.
For more details on how the points and bounding boxes are created, check the `_generate_crop_boxes`
function. The image is also preprocessed using the `image_processor`. This function `yields` a minibatch of
`points_per_batch`.
2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once.
Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the
tensors and models are on the same device.
3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps
are induced:
- image_processor.postprocess_masks (run on each minibatch loop): takes in the raw output masks,
resizes them according
to the image size, and transforms there to binary masks.
- image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and
`stability_scores`. Also
applies a variety of filters based on non maximum suppression to remove bad masks.
- image_processor.postprocess_masks_for_amg applies the NSM on the mask to only keep relevant ones.
Arguments:
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
[`PreTrainedTokenizer`].
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode the input.
points_per_batch (*optional*, int, default to 64):
Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
memory.
output_bboxes_mask (`bool`, *optional*, default to `False`):
Whether or not to output the bounding box predictions.
output_rle_masks (`bool`, *optional*, default to `False`):
Whether or not to output the masks in `RLE` format
Example:
```python
>>> from transformers import pipeline
>>> generator = pipeline(model="facebook/sam-vit-h", task="mask-generation")
>>> outputs = generator(
... "http://images.cocodataset.org/val2017/000000039769.jpg",
... )
>>> outputs = generator(
... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
... )
```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"mask-generation"`.
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
requires_backends(self, "vision")
requires_backends(self, "torch")
if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING)
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
postprocess_kwargs = {}
forward_params = {}
# preprocess args
if "points_per_batch" in kwargs:
preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
if "points_per_crop" in kwargs:
preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
if "crops_n_layers" in kwargs:
preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
if "crop_overlap_ratio" in kwargs:
preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
if "crop_n_points_downscale_factor" in kwargs:
preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
# postprocess args
if "pred_iou_thresh" in kwargs:
forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
if "stability_score_offset" in kwargs:
forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
if "mask_threshold" in kwargs:
forward_params["mask_threshold"] = kwargs["mask_threshold"]
if "stability_score_thresh" in kwargs:
forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
if "crops_nms_thresh" in kwargs:
postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
if "output_rle_mask" in kwargs:
postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
if "output_bboxes_mask" in kwargs:
postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
return preprocess_kwargs, forward_params, postprocess_kwargs
def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs):
"""
Generates binary segmentation masks
Args:
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
Image or list of images.
mask_threshold (`float`, *optional*, defaults to 0.0):
Threshold to use when turning the predicted masks into binary values.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
binarize the model's mask predictions.
stability_score_offset (`int`, *optional*, defaults to 1):
The amount to shift the cutoff when calculated the stability score.
crops_nms_thresh (`float`, *optional*, defaults to 0.7):
The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
crops_n_layers (`int`, *optional*, defaults to 0):
If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
layers to run, where each layer has 2**i_layer number of image crops.
crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
Return:
`Dict`: A dictionary with the following keys:
- **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape `(width,
height)` of the original image. Returns a mask filled with zeros if no object is found.
- **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of
the "object" described by the label and the mask.
"""
return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)
def preprocess(
self,
image,
points_per_batch=64,
crops_n_layers: int = 0,
crop_overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[int] = 1,
):
image = load_image(image)
target_size = self.image_processor.size["longest_edge"]
crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
)
model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
with self.device_placement():
if self.framework == "pt":
inference_context = self.get_inference_context()
with inference_context():
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
model_inputs["image_embeddings"] = image_embeddings
n_points = grid_points.shape[1]
points_per_batch = points_per_batch if points_per_batch is not None else n_points
if points_per_batch <= 0:
raise ValueError(
"Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
"To return all points at once, set points_per_batch to None"
)
for i in range(0, n_points, points_per_batch):
batched_points = grid_points[:, i : i + points_per_batch, :, :]
labels = input_labels[:, i : i + points_per_batch]
is_last = i == n_points - points_per_batch
yield {
"input_points": batched_points,
"input_labels": labels,
"input_boxes": crop_boxes,
"is_last": is_last,
**model_inputs,
}
def _forward(
self,
model_inputs,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
):
input_boxes = model_inputs.pop("input_boxes")
is_last = model_inputs.pop("is_last")
original_sizes = model_inputs.pop("original_sizes").tolist()
reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()
model_outputs = self.model(**model_inputs)
# post processing happens here in order to avoid CPU GPU copies of ALL the masks
low_resolution_masks = model_outputs["pred_masks"]
masks = self.image_processor.post_process_masks(
low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False
)
iou_scores = model_outputs["iou_scores"]
masks, iou_scores, boxes = self.image_processor.filter_masks(
masks[0],
iou_scores[0],
original_sizes[0],
input_boxes[0],
pred_iou_thresh,
stability_score_thresh,
mask_threshold,
stability_score_offset,
)
return {
"masks": masks,
"is_last": is_last,
"boxes": boxes,
"iou_scores": iou_scores,
}
def postprocess(
self,
model_outputs,
output_rle_mask=False,
output_bboxes_mask=False,
crops_nms_thresh=0.7,
):
all_scores = []
all_masks = []
all_boxes = []
for model_output in model_outputs:
all_scores.append(model_output.pop("iou_scores"))
all_masks.extend(model_output.pop("masks"))
all_boxes.append(model_output.pop("boxes"))
all_scores = torch.cat(all_scores)
all_boxes = torch.cat(all_boxes)
output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
all_masks, all_scores, all_boxes, crops_nms_thresh
)
extra = defaultdict(list)
for output in model_outputs:
for k, v in output.items():
extra[k].append(v)
optional = {}
if output_rle_mask:
optional["rle_mask"] = rle_mask
if output_bboxes_mask:
optional["bounding_boxes"] = bounding_boxes
return {"masks": output_masks, "scores": iou_scores, **optional, **extra}
...@@ -475,6 +475,9 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None ...@@ -475,6 +475,9 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
MODEL_FOR_MASK_GENERATION_MAPPING = None
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
......
...@@ -464,8 +464,8 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -464,8 +464,8 @@ class SamModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4)) self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
def test_inference_mask_generation_one_point_one_bb(self): def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-h") model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-h") processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import unittest
from typing import Dict
import numpy as np
from transformers import MODEL_FOR_MASK_GENERATION_MAPPING, is_vision_available, pipeline
from transformers.pipelines import MaskGenerationPipeline
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_vision,
slow,
)
if is_vision_available():
from PIL import Image
def hashimage(image: Image) -> str:
m = hashlib.md5(image.tobytes())
return m.hexdigest()[:10]
def mask_to_test_readable(mask: Image) -> Dict:
npimg = np.array(mask)
shape = npimg.shape
return {"hash": hashimage(mask), "shape": shape}
@is_pipeline_test
@require_vision
@require_torch
class MaskGenerationPipelineTests(unittest.TestCase):
model_mapping = dict(
(list(MODEL_FOR_MASK_GENERATION_MAPPING.items()) if MODEL_FOR_MASK_GENERATION_MAPPING else [])
)
def get_test_pipeline(self, model, tokenizer, processor):
image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor)
return image_segmenter, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",
]
@require_tf
@unittest.skip("Image segmentation not implemented in TF")
def test_small_model_tf(self):
pass
@slow
@require_torch
def test_small_model_pt(self):
image_segmenter = pipeline("mask-generation", model="facebook/sam-vit-huge")
outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", points_per_batch=256)
# Shortening by hashing
new_outupt = []
for i, o in enumerate(outputs["masks"]):
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
# fmt: off
self.assertEqual(
nested_simplify(new_outupt, decimals=4),
[
{'mask': {'hash': '115ad19f5f', 'shape': (480, 640)}, 'scores': 1.0444},
{'mask': {'hash': '6affa964c6', 'shape': (480, 640)}, 'scores': 1.021},
{'mask': {'hash': 'dfe28a0388', 'shape': (480, 640)}, 'scores': 1.0167},
{'mask': {'hash': 'c0a5f4a318', 'shape': (480, 640)}, 'scores': 1.0132},
{'mask': {'hash': 'fe8065c197', 'shape': (480, 640)}, 'scores': 1.0053},
{'mask': {'hash': 'e2d0b7a0b7', 'shape': (480, 640)}, 'scores': 0.9967},
{'mask': {'hash': '453c7844bd', 'shape': (480, 640)}, 'scores': 0.993},
{'mask': {'hash': '3d44f2926d', 'shape': (480, 640)}, 'scores': 0.9909},
{'mask': {'hash': '64033ddc3f', 'shape': (480, 640)}, 'scores': 0.9879},
{'mask': {'hash': '801064ff79', 'shape': (480, 640)}, 'scores': 0.9834},
{'mask': {'hash': '6172f276ef', 'shape': (480, 640)}, 'scores': 0.9716},
{'mask': {'hash': 'b49e60e084', 'shape': (480, 640)}, 'scores': 0.9612},
{'mask': {'hash': 'a811e775fd', 'shape': (480, 640)}, 'scores': 0.9599},
{'mask': {'hash': 'a6a8ebcf4b', 'shape': (480, 640)}, 'scores': 0.9552},
{'mask': {'hash': '9d8257e080', 'shape': (480, 640)}, 'scores': 0.9532},
{'mask': {'hash': '32de6454a8', 'shape': (480, 640)}, 'scores': 0.9516},
{'mask': {'hash': 'af3d4af2c8', 'shape': (480, 640)}, 'scores': 0.9499},
{'mask': {'hash': '3c6db475fb', 'shape': (480, 640)}, 'scores': 0.9483},
{'mask': {'hash': 'c290813fb9', 'shape': (480, 640)}, 'scores': 0.9464},
{'mask': {'hash': 'b6f0b8f606', 'shape': (480, 640)}, 'scores': 0.943},
{'mask': {'hash': '92ce16bfdf', 'shape': (480, 640)}, 'scores': 0.943},
{'mask': {'hash': 'c749b25868', 'shape': (480, 640)}, 'scores': 0.9408},
{'mask': {'hash': 'efb6cab859', 'shape': (480, 640)}, 'scores': 0.9335},
{'mask': {'hash': '1ff2eafb30', 'shape': (480, 640)}, 'scores': 0.9326},
{'mask': {'hash': '788b798e24', 'shape': (480, 640)}, 'scores': 0.9262},
{'mask': {'hash': 'abea804f0e', 'shape': (480, 640)}, 'scores': 0.8999},
{'mask': {'hash': '7b9e8ddb73', 'shape': (480, 640)}, 'scores': 0.8986},
{'mask': {'hash': 'cd24047c8a', 'shape': (480, 640)}, 'scores': 0.8984},
{'mask': {'hash': '6943e6bcbd', 'shape': (480, 640)}, 'scores': 0.8873},
{'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871}
],
)
# fmt: on
@require_torch
@slow
def test_threshold(self):
model_id = "facebook/sam-vit-huge"
image_segmenter = pipeline("mask-generation", model=model_id)
outputs = image_segmenter(
"http://images.cocodataset.org/val2017/000000039769.jpg", pred_iou_thresh=1, points_per_batch=256
)
# Shortening by hashing
new_outupt = []
for i, o in enumerate(outputs["masks"]):
new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]
self.assertEqual(
nested_simplify(new_outupt, decimals=4),
[
{"mask": {"hash": "115ad19f5f", "shape": (480, 640)}, "scores": 1.0444},
{"mask": {"hash": "6affa964c6", "shape": (480, 640)}, "scores": 1.0210},
{"mask": {"hash": "dfe28a0388", "shape": (480, 640)}, "scores": 1.0167},
{"mask": {"hash": "c0a5f4a318", "shape": (480, 640)}, "scores": 1.0132},
{"mask": {"hash": "fe8065c197", "shape": (480, 640)}, "scores": 1.0053},
],
)
...@@ -98,6 +98,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [ ...@@ -98,6 +98,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
), ),
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"), ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"), ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment