"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "c645a86588ade1c81050fa6549675c9ae583605d"
Unverified Commit 292acd71 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Update image processor parameters if creating with kwargs (#20866)

* Update parameters if creating with kwargs

* Shallow copy to prevent mutating input

* Pass all args in constructor dict - warnings in init

* Fix typo
parent f9e977be
...@@ -316,8 +316,17 @@ class ImageProcessingMixin(PushToHubMixin): ...@@ -316,8 +316,17 @@ class ImageProcessingMixin(PushToHubMixin):
[`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
parameters. parameters.
""" """
image_processor_dict = image_processor_dict.copy()
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# The `size` parameter is a dict and was previously an int or tuple in feature extractors.
# We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
# dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
if "size" in kwargs and "size" in image_processor_dict:
image_processor_dict["size"] = kwargs.pop("size")
if "crop_size" in kwargs and "crop_size" in image_processor_dict:
image_processor_dict["crop_size"] = kwargs.pop("crop_size")
image_processor = cls(**image_processor_dict) image_processor = cls(**image_processor_dict)
# Update image_processor with kwargs if needed # Update image_processor with kwargs if needed
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Image processor class for Beit.""" """Image processor class for Beit."""
import warnings import warnings
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -131,6 +131,17 @@ class BeitImageProcessor(BaseImageProcessor): ...@@ -131,6 +131,17 @@ class BeitImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_reduce_labels = do_reduce_labels self.do_reduce_labels = do_reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
is created using from_dict and kwargs e.g. `BeitImageProcessor.from_pretrained(checkpoint, reduce_labels=True)`
"""
image_processor_dict = image_processor_dict.copy()
if "reduce_labels" in kwargs:
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
return super().from_dict(image_processor_dict, **kwargs)
def resize( def resize(
self, self,
image: np.ndarray, image: np.ndarray,
......
...@@ -815,6 +815,21 @@ class ConditionalDetrImageProcessor(BaseImageProcessor): ...@@ -815,6 +815,21 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->ConditionalDetr
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `ConditionalDetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->ConditionalDetr # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->ConditionalDetr
def prepare_annotation( def prepare_annotation(
self, self,
......
...@@ -813,6 +813,21 @@ class DeformableDetrImageProcessor(BaseImageProcessor): ...@@ -813,6 +813,21 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->DeformableDetr
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `DeformableDetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DeformableDetr # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DeformableDetr
def prepare_annotation( def prepare_annotation(
self, self,
......
...@@ -797,6 +797,20 @@ class DetrImageProcessor(BaseImageProcessor): ...@@ -797,6 +797,20 @@ class DetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
def prepare_annotation( def prepare_annotation(
self, self,
image: np.ndarray, image: np.ndarray,
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import math import math
import random import random
from functools import lru_cache from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -293,6 +293,19 @@ class FlavaImageProcessor(BaseImageProcessor): ...@@ -293,6 +293,19 @@ class FlavaImageProcessor(BaseImageProcessor):
self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`
"""
image_processor_dict = image_processor_dict.copy()
if "codebook_size" in kwargs:
image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
if "codebook_crop_size" in kwargs:
image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
return super().from_dict(image_processor_dict, **kwargs)
@lru_cache() @lru_cache()
def masking_generator( def masking_generator(
self, self,
......
...@@ -400,7 +400,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -400,7 +400,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
if "size_divisibility" in kwargs: if "size_divisibility" in kwargs:
warnings.warn( warnings.warn(
"The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use " "The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use "
"`size_divisibility` instead.", "`size_divisor` instead.",
FutureWarning, FutureWarning,
) )
size_divisor = kwargs.pop("size_divisibility") size_divisor = kwargs.pop("size_divisibility")
...@@ -432,6 +432,19 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -432,6 +432,19 @@ class MaskFormerImageProcessor(BaseImageProcessor):
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.reduce_labels = reduce_labels self.reduce_labels = reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `MaskFormerImageProcessor.from_pretrained(checkpoint, max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "size_divisibility" in kwargs:
image_processor_dict["size_divisibility"] = kwargs.pop("size_divisibility")
return super().from_dict(image_processor_dict, **kwargs)
@property @property
def size_divisibility(self): def size_divisibility(self):
warnings.warn( warnings.warn(
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Image processor class for Segformer.""" """Image processor class for Segformer."""
import warnings import warnings
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -119,6 +119,18 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -119,6 +119,18 @@ class SegformerImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_reduce_labels = do_reduce_labels self.do_reduce_labels = do_reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
reduce_labels=True)`
"""
image_processor_dict = image_processor_dict.copy()
if "reduce_labels" in kwargs:
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
return super().from_dict(image_processor_dict, **kwargs)
def resize( def resize(
self, self,
image: np.ndarray, image: np.ndarray,
......
...@@ -185,6 +185,18 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -185,6 +185,18 @@ class ViltImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
is created using from_dict and kwargs e.g. `ViltImageProcessor.from_pretrained(checkpoint,
pad_and_return_pixel_mask=False)`
"""
image_processor_dict = image_processor_dict.copy()
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
def resize( def resize(
self, self,
image: np.ndarray, image: np.ndarray,
......
...@@ -725,6 +725,21 @@ class YolosImageProcessor(BaseImageProcessor): ...@@ -725,6 +725,21 @@ class YolosImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->Yolos
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `YolosImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation
def prepare_annotation( def prepare_annotation(
self, self,
......
...@@ -125,6 +125,19 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -125,6 +125,19 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 20, "width": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
self.assertEqual(feature_extractor.do_reduce_labels, False)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, crop_size=84, reduce_labels=True
)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
self.assertEqual(feature_extractor.do_reduce_labels, True)
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -135,6 +135,15 @@ class ChineseCLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes ...@@ -135,6 +135,15 @@ class ChineseCLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 224, "width": 224})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -135,6 +135,15 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -135,6 +135,15 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -133,6 +133,17 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, uni ...@@ -133,6 +133,17 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, uni
self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -96,6 +96,13 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T ...@@ -96,6 +96,13 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -135,6 +135,17 @@ class DeformableDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unit ...@@ -135,6 +135,17 @@ class DeformableDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unit
self.assertTrue(hasattr(feature_extractor, "do_pad")) self.assertTrue(hasattr(feature_extractor, "do_pad"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -103,6 +103,15 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -103,6 +103,15 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 20, "width": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -136,6 +136,17 @@ class DetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -136,6 +136,17 @@ class DetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_pad")) self.assertTrue(hasattr(feature_extractor, "do_pad"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -103,6 +103,17 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test ...@@ -103,6 +103,17 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 20})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
# Previous config had dimensions in (width, height) order
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=(42, 84))
self.assertEqual(feature_extractor.size, {"height": 84, "width": 42})
def test_batch_feature(self): def test_batch_feature(self):
pass pass
......
...@@ -92,6 +92,13 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa ...@@ -92,6 +92,13 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
def test_call_pil(self): def test_call_pil(self):
# Initialize feature_extractor # Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment