Unverified commit de6d19ea authored by amyeroberts, committed by GitHub

Add segmentation + object detection image processors (#20160)

* Add transforms for object detection

* DETR models + Yolos

* Scrappy additions

* Maskformer image processor

* Fix up; MaskFormer tests

* Update owlvit processor

* Add to docs

* OwlViT tests

* Update pad logic

* Remove changes to transforms

* Import fn directly

* Update to include pad transformation

* Remove unintended changes

* Add new owlvit post processing function

* Tidy up

* Fix copies

* Fix some copies

* Include device fix

* Fix scipy imports

* Update _pad_image

* Update padding functionality

* Fix bug

* Properly handle ignore index

* Fix up

* Remove defaults to None in docstrings

* Fix docstrings & docs

* Fix sizes bug

* Resolve conflicts in init

* Cast to float after resizing

* Tidy & add size if missing

* Allow kwargs when processing for owlvit

* Update test values
parent ae3cbc95
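The headline change: the DETR family, YOLOS, MaskFormer and OWL-ViT gain dedicated `*ImageProcessor` classes alongside the existing feature extractors. A minimal sketch of the new API follows; the checkpoint name and the local file `cats.png` are illustrative, not part of this diff.

```python
# Minimal sketch of the new image-processor API (checkpoint and image
# file are illustrative assumptions).
from PIL import Image

from transformers import OwlViTImageProcessor

image_processor = OwlViTImageProcessor.from_pretrained("google/owlvit-base-patch32")
image = Image.open("cats.png")

inputs = image_processor(images=image, return_tensors="np")
print(inputs["pixel_values"].shape)  # (batch, num_channels, height, width)
```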
@@ -47,6 +47,7 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["feature_extraction_owlvit"] = ["OwlViTFeatureExtractor"]
+    _import_structure["image_processing_owlvit"] = ["OwlViTImageProcessor"]
 
 try:
     if not is_torch_available():
@@ -80,6 +81,7 @@ if TYPE_CHECKING:
         pass
     else:
         from .feature_extraction_owlvit import OwlViTFeatureExtractor
+        from .image_processing_owlvit import OwlViTImageProcessor
 
     try:
         if not is_torch_available():
This diff is collapsed.
@@ -15,6 +15,7 @@
 """
 Image/Text processor class for OWL-ViT
 """
 
+from typing import List
 
 import numpy as np
@@ -33,15 +34,15 @@ class OwlViTProcessor(ProcessorMixin):
     Args:
         feature_extractor ([`OwlViTFeatureExtractor`]):
-            The feature extractor is a required input.
+            The image processor is a required input.
         tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
             The tokenizer is a required input.
     """
     feature_extractor_class = "OwlViTFeatureExtractor"
     tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
 
-    def __init__(self, feature_extractor, tokenizer):
-        super().__init__(feature_extractor, tokenizer)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
 
     def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs):
         """
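With `__init__` now forwarding `*args`/`**kwargs` to `ProcessorMixin`, the processor no longer hard-codes its two components, and its `__call__` accepts `text`, `images`, or `query_images` (for image-guided detection). A hedged usage sketch, again assuming the `google/owlvit-base-patch32` checkpoint and an illustrative local image:

```python
# Sketch of text-conditioned preprocessing with OwlViTProcessor
# (checkpoint name and image file are illustrative).
from PIL import Image

from transformers import OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
image = Image.open("street.png")

# One list of text queries per image.
inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="np")
print(sorted(inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']
```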
@@ -287,7 +287,6 @@ class SegformerImageProcessor(BaseImageProcessor):
         do_reduce_labels: bool = None,
         do_resize: bool = None,
         size: Dict[str, int] = None,
-        resample: PILImageResampling = None,
     ) -> np.ndarray:
         """Preprocesses a single mask."""
         segmentation_map = to_numpy_array(segmentation_map)
@@ -301,7 +300,7 @@ class SegformerImageProcessor(BaseImageProcessor):
             image=segmentation_map,
             do_reduce_labels=do_reduce_labels,
             do_resize=do_resize,
-            resample=PIL.Image.NEAREST,
+            resample=PILImageResampling.NEAREST,
             size=size,
             do_rescale=False,
             do_normalize=False,
@@ -438,7 +437,6 @@ class SegformerImageProcessor(BaseImageProcessor):
                 segmentation_map=segmentation_map,
                 do_reduce_labels=do_reduce_labels,
                 do_resize=do_resize,
-                resample=PIL.Image.NEAREST,
                 size=size,
             )
             for segmentation_map in segmentation_maps
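Dropping the caller-supplied `resample` and pinning `PILImageResampling.NEAREST` is deliberate: segmentation maps hold class ids, and any interpolating filter would blend neighbouring ids into labels that do not exist. A standalone illustration using PIL directly (not the library's internal resize helper):

```python
# Why masks are resized with NEAREST: interpolating filters invent class ids.
import numpy as np
from PIL import Image

mask = np.array([[0, 0, 2, 2],
                 [0, 0, 2, 2],
                 [1, 1, 1, 1],
                 [1, 1, 1, 1]], dtype=np.uint8)

nearest = np.array(Image.fromarray(mask).resize((8, 8), resample=Image.NEAREST))
bilinear = np.array(Image.fromarray(mask).resize((8, 8), resample=Image.BILINEAR))

print(np.unique(nearest))   # [0 1 2] -- only the original class ids survive
print(np.unique(bilinear))  # may include blended ids that label no real class
```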
@@ -29,6 +29,7 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["feature_extraction_yolos"] = ["YolosFeatureExtractor"]
+    _import_structure["image_processing_yolos"] = ["YolosImageProcessor"]
 
 try:
     if not is_torch_available():
@@ -54,6 +55,7 @@ if TYPE_CHECKING:
         pass
     else:
         from .feature_extraction_yolos import YolosFeatureExtractor
+        from .image_processing_yolos import YolosImageProcessor
 
     try:
         if not is_torch_available():
This diff is collapsed.
@@ -64,6 +64,13 @@ class ConditionalDetrFeatureExtractor(metaclass=DummyObject):
         requires_backends(self, ["vision"])
 
 
+class ConditionalDetrImageProcessor(metaclass=DummyObject):
+    _backends = ["vision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["vision"])
+
+
 class ConvNextFeatureExtractor(metaclass=DummyObject):
     _backends = ["vision"]
@@ -85,6 +92,13 @@ class DeformableDetrFeatureExtractor(metaclass=DummyObject):
         requires_backends(self, ["vision"])
 
 
+class DeformableDetrImageProcessor(metaclass=DummyObject):
+    _backends = ["vision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["vision"])
+
+
 class DeiTFeatureExtractor(metaclass=DummyObject):
     _backends = ["vision"]
@@ -106,6 +120,13 @@ class DetrFeatureExtractor(metaclass=DummyObject):
         requires_backends(self, ["vision"])
 
 
+class DetrImageProcessor(metaclass=DummyObject):
+    _backends = ["vision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["vision"])
+
+
 class DonutFeatureExtractor(metaclass=DummyObject):
     _backends = ["vision"]
@@ -232,6 +253,13 @@ class MaskFormerFeatureExtractor(metaclass=DummyObject):
         requires_backends(self, ["vision"])
 
 
+class MaskFormerImageProcessor(metaclass=DummyObject):
+    _backends = ["vision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["vision"])
+
+
 class MobileNetV1FeatureExtractor(metaclass=DummyObject):
     _backends = ["vision"]
@@ -281,6 +309,13 @@ class OwlViTFeatureExtractor(metaclass=DummyObject):
         requires_backends(self, ["vision"])
 
 
+class OwlViTImageProcessor(metaclass=DummyObject):
+    _backends = ["vision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["vision"])
+
+
 class PerceiverFeatureExtractor(metaclass=DummyObject):
     _backends = ["vision"]
@@ -377,3 +412,10 @@ class YolosFeatureExtractor(metaclass=DummyObject):
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["vision"])
+
+
+class YolosImageProcessor(metaclass=DummyObject):
+    _backends = ["vision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["vision"])
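Each new image processor gets the standard dummy-object treatment: if the vision backend is not installed, instantiating the class (or touching any of its attributes) raises an informative `ImportError`. A condensed sketch of the pattern, simplified from `transformers.utils`; the error message wording here is illustrative.

```python
# Condensed sketch of the dummy-object pattern (simplified; the real
# helpers live in transformers.utils and build a richer message).
class DummyObject(type):
    """Metaclass: even class-level attribute access fails with a clear error."""

    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)


def requires_backends(obj, backends):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class YolosImageProcessor(metaclass=DummyObject):
    _backends = ["vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
```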
@@ -44,12 +44,16 @@ class ConditionalDetrFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
-        max_size=1333,  # by setting max_size > max_resolution we're effectively not testing this :p
+        size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
+        do_rescale=True,
+        rescale_factor=1 / 255,
+        do_pad=True,
     ):
+        # by setting size["longest_edge"] > max_resolution we're effectively not testing this :p
+        size = size if size is not None else {"shortest_edge": 18, "longest_edge": 1333}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -57,19 +61,23 @@ class ConditionalDetrFeatureExtractionTester(unittest.TestCase):
         self.max_resolution = max_resolution
         self.do_resize = do_resize
         self.size = size
-        self.max_size = max_size
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_pad = do_pad
 
     def prepare_feat_extract_dict(self):
         return {
             "do_resize": self.do_resize,
             "size": self.size,
-            "max_size": self.max_size,
             "do_normalize": self.do_normalize,
             "image_mean": self.image_mean,
             "image_std": self.image_std,
+            "do_rescale": self.do_rescale,
+            "rescale_factor": self.rescale_factor,
+            "do_pad": self.do_pad,
         }
 
     def get_expected_values(self, image_inputs, batched=False):
@@ -84,14 +92,14 @@ class ConditionalDetrFeatureExtractionTester(unittest.TestCase):
             else:
                 h, w = image.shape[1], image.shape[2]
             if w < h:
-                expected_height = int(self.size * h / w)
-                expected_width = self.size
+                expected_height = int(self.size["shortest_edge"] * h / w)
+                expected_width = self.size["shortest_edge"]
             elif w > h:
-                expected_height = self.size
-                expected_width = int(self.size * w / h)
+                expected_height = self.size["shortest_edge"]
+                expected_width = int(self.size["shortest_edge"] * w / h)
             else:
-                expected_height = self.size
-                expected_width = self.size
+                expected_height = self.size["shortest_edge"]
+                expected_width = self.size["shortest_edge"]
 
         else:
             expected_values = []
@@ -124,7 +132,6 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         self.assertTrue(hasattr(feature_extractor, "do_normalize"))
         self.assertTrue(hasattr(feature_extractor, "do_resize"))
         self.assertTrue(hasattr(feature_extractor, "size"))
-        self.assertTrue(hasattr(feature_extractor, "max_size"))
 
     def test_batch_feature(self):
         pass
@@ -230,7 +237,7 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
     def test_equivalence_pad_and_create_pixel_mask(self):
         # Initialize feature_extractors
         feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
-        feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
+        feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False, do_rescale=False)
 
         # create random PyTorch tensors
         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
         for image in image_inputs:
@@ -331,7 +338,7 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         expected_class_labels = torch.tensor([17, 17, 63, 75, 75, 93])
         self.assertTrue(torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels))
         # verify masks
-        expected_masks_sum = 822338
+        expected_masks_sum = 822873
         self.assertEqual(encoding["labels"][0]["masks"].sum().item(), expected_masks_sum)
         # verify orig_size
         expected_orig_size = torch.tensor([480, 640])
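The tester's `size` moves from a bare int plus `max_size` to a `{"shortest_edge", "longest_edge"}` dict, and `get_expected_values` pins the image's short side to `shortest_edge` while preserving aspect ratio. The same logic as a standalone sketch (the function name is illustrative):

```python
# Standalone sketch of the expected resize, mirroring get_expected_values.
def expected_resized_size(height, width, shortest_edge=18):
    if width < height:
        return int(shortest_edge * height / width), shortest_edge
    if width > height:
        return shortest_edge, int(shortest_edge * width / height)
    return shortest_edge, shortest_edge


print(expected_resized_size(400, 300))  # (24, 18): short side pinned to 18
```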
@@ -43,9 +43,9 @@ class OwlViTFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=20,
+        size=None,
         do_center_crop=True,
-        crop_size=18,
+        crop_size=None,
         do_normalize=True,
         image_mean=[0.48145466, 0.4578275, 0.40821073],
         image_std=[0.26862954, 0.26130258, 0.27577711],
@@ -58,9 +58,9 @@ class OwlViTFeatureExtractionTester(unittest.TestCase):
         self.min_resolution = min_resolution
         self.max_resolution = max_resolution
         self.do_resize = do_resize
-        self.size = size
+        self.size = size if size is not None else {"height": 18, "width": 18}
         self.do_center_crop = do_center_crop
-        self.crop_size = crop_size
+        self.crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
@@ -119,8 +119,8 @@ class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -131,8 +131,8 @@ class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -151,8 +151,8 @@ class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -163,8 +163,8 @@ class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -183,8 +183,8 @@ class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -195,7 +195,7 @@ class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
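Same migration on the OWL-ViT side: `size` and `crop_size` become explicit `{"height", "width"}` dicts, which is why the shape assertions now index `crop_size["height"]` and `crop_size["width"]`. Constructing a processor the way the tester does might look like this (argument values mirror the test defaults):

```python
# Sketch: building an OwlViTImageProcessor with dict-valued size arguments,
# mirroring the tester defaults above.
from transformers import OwlViTImageProcessor

image_processor = OwlViTImageProcessor(
    do_resize=True,
    size={"height": 18, "width": 18},
    do_center_crop=True,
    crop_size={"height": 18, "width": 18},
)
print(image_processor.crop_size)  # {'height': 18, 'width': 18}
```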