Unverified commit 292acd71 authored by amyeroberts, committed by GitHub

Update image processor parameters if creating with kwargs (#20866)

* Update parameters if creating with kwargs

* Shallow copy to prevent mutating input

* Pass all args in constructor dict - warnings in init

* Fix typo
parent f9e977be
@@ -316,8 +316,17 @@ class ImageProcessingMixin(PushToHubMixin):
[`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
parameters.
"""
image_processor_dict = image_processor_dict.copy()
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# The `size` parameter is a dict and was previously an int or tuple in feature extractors.
# We set `size` here directly in the `image_processor_dict` so that it is converted to the appropriate
# dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
if "size" in kwargs and "size" in image_processor_dict:
image_processor_dict["size"] = kwargs.pop("size")
if "crop_size" in kwargs and "crop_size" in image_processor_dict:
image_processor_dict["crop_size"] = kwargs.pop("crop_size")
image_processor = cls(**image_processor_dict)
# Update image_processor with kwargs if needed
......
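For context, merging `size`/`crop_size` into the dict before instantiation means a plain int or tuple kwarg is normalised by the processor's own `__init__` instead of clobbering the already-converted dict afterwards. A minimal sketch of the resulting behaviour (not part of this diff; the checkpoint name is an illustrative assumption, and the expected value mirrors the CLIP test further down):

```python
from transformers import CLIPImageProcessor

# Because `size=42` is folded into the saved config before the class is
# instantiated, the int is converted into the processor's size dict rather
# than overwriting it after construction.
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32", size=42)
print(image_processor.size)  # expected: {"shortest_edge": 42}
```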
@@ -15,7 +15,7 @@
"""Image processor class for Beit."""
import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -131,6 +131,17 @@ class BeitImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_reduce_labels = do_reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
is created using from_dict and kwargs e.g. `BeitImageProcessor.from_pretrained(checkpoint, reduce_labels=True)`
"""
image_processor_dict = image_processor_dict.copy()
if "reduce_labels" in kwargs:
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
return super().from_dict(image_processor_dict, **kwargs)
def resize(
self,
image: np.ndarray,
......
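A hedged usage sketch of what the Beit override enables (the checkpoint name is an assumption; the expected value mirrors the Beit test at the end of this diff):

```python
from transformers import BeitImageProcessor

# `reduce_labels` is forwarded through the constructor dict, where the init
# handles the deprecated name and maps it onto `do_reduce_labels`.
image_processor = BeitImageProcessor.from_pretrained(
    "microsoft/beit-base-patch16-224", reduce_labels=True
)
print(image_processor.do_reduce_labels)  # expected: True
```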
@@ -815,6 +815,21 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->ConditionalDetr
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `ConditionalDetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->ConditionalDetr
def prepare_annotation(
self,
......
@@ -813,6 +813,21 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->DeformableDetr
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `DeformableDetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DeformableDetr
def prepare_annotation(
self,
......
@@ -797,6 +797,20 @@ class DetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
def prepare_annotation(
self,
image: np.ndarray,
......
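This DETR override is the source of the `# Copied from` versions in ConditionalDetr, DeformableDetr, and Yolos above and below, so all four processors accept the legacy kwargs the same way. A hedged sketch (the checkpoint name is an assumption; the expected values mirror the DETR test further down):

```python
from transformers import DetrImageProcessor

# The legacy int `size` plus `max_size` pair is converted inside the
# constructor into the shortest/longest-edge dict, and the deprecated
# `pad_and_return_pixel_mask` flag is mapped onto `do_pad`.
image_processor = DetrImageProcessor.from_pretrained(
    "facebook/detr-resnet-50", size=600, max_size=800, pad_and_return_pixel_mask=False
)
print(image_processor.size)    # expected: {"shortest_edge": 600, "longest_edge": 800}
print(image_processor.do_pad)  # expected: False
```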
@@ -17,7 +17,7 @@
import math
import random
from functools import lru_cache
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
@@ -293,6 +293,19 @@ class FlavaImageProcessor(BaseImageProcessor):
self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`
"""
image_processor_dict = image_processor_dict.copy()
if "codebook_size" in kwargs:
image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
if "codebook_crop_size" in kwargs:
image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
return super().from_dict(image_processor_dict, **kwargs)
@lru_cache()
def masking_generator(
self,
......
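A hedged sketch of the FLAVA-specific codebook kwargs (the checkpoint name is an assumption, and the exact form of the resulting dict is illustrative):

```python
from transformers import FlavaImageProcessor

# `codebook_size` is merged into the constructor dict, so an int is
# normalised into a size dict just like the regular `size` parameter.
image_processor = FlavaImageProcessor.from_pretrained("facebook/flava-full", codebook_size=600)
print(image_processor.codebook_size)  # e.g. {"height": 600, "width": 600}
```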
@@ -400,7 +400,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
if "size_divisibility" in kwargs:
warnings.warn(
"The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use "
"`size_divisibility` instead.",
"`size_divisor` instead.",
FutureWarning,
)
size_divisor = kwargs.pop("size_divisibility")
@@ -432,6 +432,19 @@ class MaskFormerImageProcessor(BaseImageProcessor):
self.ignore_index = ignore_index
self.reduce_labels = reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `MaskFormerImageProcessor.from_pretrained(checkpoint, max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "size_divisibility" in kwargs:
image_processor_dict["size_divisibility"] = kwargs.pop("size_divisibility")
return super().from_dict(image_processor_dict, **kwargs)
@property
def size_divisibility(self):
warnings.warn(
......
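A hedged sketch of the deprecated-kwarg path for MaskFormer (the checkpoint name is an assumption):

```python
from transformers import MaskFormerImageProcessor

# Both deprecated names are forwarded through the constructor dict, where
# the init emits a FutureWarning and maps them onto `size` / `size_divisor`.
image_processor = MaskFormerImageProcessor.from_pretrained(
    "facebook/maskformer-swin-base-ade", max_size=800, size_divisibility=32
)
print(image_processor.size_divisor)  # expected: 32
```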
@@ -15,7 +15,7 @@
"""Image processor class for Segformer."""
import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -119,6 +119,18 @@ class SegformerImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_reduce_labels = do_reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
reduce_labels=True)`
"""
image_processor_dict = image_processor_dict.copy()
if "reduce_labels" in kwargs:
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
return super().from_dict(image_processor_dict, **kwargs)
def resize(
self,
image: np.ndarray,
......
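The Segformer override mirrors the Beit one. A hedged sketch using `from_dict` directly rather than `from_pretrained`, starting from a default config:

```python
from transformers import SegformerImageProcessor

# Rebuild a processor from its own config dict with the legacy kwarg; the
# override routes it into the constructor dict so the init can handle it.
config_dict = SegformerImageProcessor().to_dict()
image_processor = SegformerImageProcessor.from_dict(config_dict, reduce_labels=True)
print(image_processor.do_reduce_labels)  # expected: True
```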
@@ -185,6 +185,18 @@ class ViltImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_pad = do_pad
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `pad_and_return_pixel_mask` is updated if image
processor is created using from_dict and kwargs e.g. `ViltImageProcessor.from_pretrained(checkpoint,
pad_and_return_pixel_mask=False)`
"""
image_processor_dict = image_processor_dict.copy()
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
def resize(
self,
image: np.ndarray,
......
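A hedged sketch of the ViLT behaviour (the checkpoint name is an assumption):

```python
from transformers import ViltImageProcessor

# The deprecated flag travels through the constructor dict and is mapped by
# the init onto the current `do_pad` attribute.
image_processor = ViltImageProcessor.from_pretrained(
    "dandelin/vilt-b32-mlm", pad_and_return_pixel_mask=False
)
print(image_processor.do_pad)  # expected: False
```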
@@ -725,6 +725,21 @@ class YolosImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->Yolos
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `YolosImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation
def prepare_annotation(
self,
......
@@ -125,6 +125,19 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 20, "width": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
self.assertEqual(feature_extractor.do_reduce_labels, False)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, crop_size=84, reduce_labels=True
)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
self.assertEqual(feature_extractor.do_reduce_labels, True)
def test_batch_feature(self):
pass
......
@@ -135,6 +135,15 @@ class ChineseCLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 224, "width": 224})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self):
pass
......
@@ -135,6 +135,15 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self):
pass
......
@@ -133,6 +133,17 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, uni
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self):
pass
......
@@ -96,6 +96,13 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
def test_batch_feature(self):
pass
......
@@ -135,6 +135,17 @@ class DeformableDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unit
self.assertTrue(hasattr(feature_extractor, "do_pad"))
self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self):
pass
......
@@ -103,6 +103,15 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 20, "width": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self):
pass
......
@@ -136,6 +136,17 @@ class DetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_pad"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self):
pass
......
@@ -103,6 +103,17 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 20})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
# Previous config had dimensions in (width, height) order
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=(42, 84))
self.assertEqual(feature_extractor.size, {"height": 84, "width": 42})
def test_batch_feature(self):
pass
......
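The Donut test above exercises a subtle point: older Donut configs stored `size` as `(width, height)`, so a tuple kwarg is flipped when it is normalised into the height/width dict. A hedged sketch (the checkpoint name is an assumption; the expected value mirrors the test above):

```python
from transformers import DonutImageProcessor

# A tuple is interpreted in the legacy (width, height) order and converted
# into the height/width dict inside the constructor.
image_processor = DonutImageProcessor.from_pretrained("naver-clova-ix/donut-base", size=(42, 84))
print(image_processor.size)  # expected: {"height": 84, "width": 42}
```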
@@ -92,6 +92,13 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
def test_call_pil(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
......