Unverified Commit f53fe35b authored by amyeroberts, committed by GitHub

Fast image processor (#28847)



* Draft fast image processors

* Draft working fast version

* py3.8 compatible cache

* Enable loading fast image processors through auto

* Tidy up; rescale behaviour based on input type

* Enable tests for fast image processors

* Smarter rescaling

* Don't default to Fast

* Safer imports

* Add necessary Pillow requirement

* Woops

* Add AutoImageProcessor test

* Fix up

* Fix test for imagegpt

* Fix test

* Review comments

* Add warning for TF and JAX input types

* Rearrange

* Return transforms

* NumpyToTensor transformation

* Rebase - include changes from upstream in ImageProcessingMixin

* Safe typing

* Fix up

* Convert mean/std to tensor to rescale

* Don't store transforms in state

* Fix up

* Update src/transformers/image_processing_utils_fast.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Warn if fast image processor available

* Update src/transformers/models/vit/image_processing_vit_fast.py

* Transpose incoming numpy images to be in CHW format

* Update mapping names based on packages, auto set fast to None

* Fix up

* Fix

* Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test

* Update src/transformers/models/vit/image_processing_vit_fast.py
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Add equivalence and speed tests

* Fix up

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
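
In practice, the opt-in looks like this (a minimal sketch; the checkpoint name is illustrative, and any checkpoint with a ViT image processor config should behave the same):

```python
from transformers import AutoImageProcessor

# The slow, PIL/NumPy-based processor remains the default
slow = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

# The torchvision-backed fast processor is opt-in
fast = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)
```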
parent edc1dffd
@@ -32,3 +32,8 @@ An image processor is in charge of preparing input features for vision models an
## BaseImageProcessor
[[autodoc]] image_processing_utils.BaseImageProcessor
## BaseImageProcessorFast
[[autodoc]] image_processing_utils_fast.BaseImageProcessorFast
@@ -158,6 +158,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] ViTImageProcessor
- preprocess
## ViTImageProcessorFast
[[autodoc]] ViTImageProcessorFast
- preprocess
<frameworkcontent>
<pt>
......
@@ -29,3 +29,4 @@ timm
albumentations >= 1.4.5
torchmetrics
pycocotools
Pillow>=10.0.1,<=15.0
@@ -1104,7 +1104,8 @@ except OptionalDependencyNotAvailable:
name for name in dir(dummy_vision_objects) if not name.startswith("_")
]
else:
_import_structure["image_processing_utils"] = ["ImageProcessingMixin"]
_import_structure["image_processing_base"] = ["ImageProcessingMixin"]
_import_structure["image_processing_utils"] = ["BaseImageProcessor"]
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"])
_import_structure["models.bit"].extend(["BitImageProcessor"])
@@ -1167,6 +1168,18 @@ else:
_import_structure["models.vivit"].append("VivitImageProcessor")
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
try:
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torchvision_objects
_import_structure["utils.dummy_torchvision_objects"] = [
name for name in dir(dummy_torchvision_objects) if not name.startswith("_")
]
else:
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["models.vit"].append("ViTImageProcessorFast")
# PyTorch-backed objects
try:
@@ -5703,7 +5716,8 @@ if TYPE_CHECKING:
except OptionalDependencyNotAvailable:
from .utils.dummy_vision_objects import *
else:
-    from .image_processing_utils import ImageProcessingMixin
+    from .image_processing_base import ImageProcessingMixin
+    from .image_processing_utils import BaseImageProcessor
from .image_utils import ImageFeatureExtractionMixin
from .models.beit import BeitFeatureExtractor, BeitImageProcessor
from .models.bit import BitImageProcessor
@@ -5793,6 +5807,15 @@ if TYPE_CHECKING:
from .models.vivit import VivitImageProcessor
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
try:
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torchvision_objects import *
else:
from .image_processing_utils_fast import BaseImageProcessorFast
from .models.vit import ViTImageProcessorFast
# Modeling
try:
if not is_torch_available():
......
This diff is collapsed.
This diff is collapsed.
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from dataclasses import dataclass
from .image_processing_utils import BaseImageProcessor
from .utils.import_utils import is_torchvision_available
if is_torchvision_available():
from torchvision.transforms import Compose
@dataclass(frozen=True)
class SizeDict:
"""
Hashable dictionary to store image size information.
"""
height: int = None
width: int = None
longest_edge: int = None
shortest_edge: int = None
max_height: int = None
max_width: int = None
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
raise KeyError(f"Key {key} not found in SizeDict.")
class BaseImageProcessorFast(BaseImageProcessor):
_transform_params = None
def _build_transforms(self, **kwargs) -> "Compose":
"""
        Given the input settings, e.g. `do_resize`, build the image transforms.
"""
raise NotImplementedError
def _validate_params(self, **kwargs) -> None:
for k, v in kwargs.items():
if k not in self._transform_params:
raise ValueError(f"Invalid transform parameter {k}={v}.")
@functools.lru_cache(maxsize=1)
def get_transforms(self, **kwargs) -> "Compose":
self._validate_params(**kwargs)
return self._build_transforms(**kwargs)
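
The design here is worth spelling out: `get_transforms` is wrapped in `functools.lru_cache`, so the composed pipeline is built once per unique set of settings, and every setting must therefore be hashable — which is exactly why `SizeDict` is a frozen (hashable) dataclass rather than a plain dict. A minimal sketch of a concrete subclass (the `ToyImageProcessorFast` class is hypothetical; torchvision is assumed to be installed):

```python
from torchvision.transforms import Compose, Resize

from transformers.image_processing_utils_fast import BaseImageProcessorFast, SizeDict


class ToyImageProcessorFast(BaseImageProcessorFast):
    _transform_params = ["do_resize", "size"]

    def _build_transforms(self, do_resize: bool, size: SizeDict) -> Compose:
        # Only build what the settings ask for
        transforms = []
        if do_resize:
            transforms.append(Resize((size.height, size.width)))
        return Compose(transforms)


processor = ToyImageProcessorFast()
t1 = processor.get_transforms(do_resize=True, size=SizeDict(height=224, width=224))
t2 = processor.get_transforms(do_resize=True, size=SizeDict(height=224, width=224))
assert t1 is t2  # the second call is served from the cache
```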
@@ -31,6 +31,7 @@ from .utils.import_utils import (
is_flax_available,
is_tf_available,
is_torch_available,
is_torchvision_available,
is_vision_available,
requires_backends,
)
@@ -50,6 +51,9 @@ if is_tf_available():
if is_flax_available():
import jax.numpy as jnp
if is_torchvision_available():
from torchvision.transforms import functional as F
def to_channel_dimension_format(
image: np.ndarray,
@@ -374,6 +378,7 @@ def normalize(
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
num_channels = image.shape[channel_axis]
@@ -802,3 +807,48 @@ def flip_channel_order(
if data_format is not None:
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
def _cast_tensor_to_float(x):
if x.is_floating_point():
return x
return x.float()
class FusedRescaleNormalize:
"""
Rescale and normalize the input image in one step.
"""
def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
self.mean = torch.tensor(mean) * (1.0 / rescale_factor)
self.std = torch.tensor(std) * (1.0 / rescale_factor)
self.inplace = inplace
def __call__(self, image: "torch.Tensor"):
image = _cast_tensor_to_float(image)
return F.normalize(image, self.mean, self.std, inplace=self.inplace)
class Rescale:
"""
Rescale the input image by rescale factor: image *= rescale_factor.
"""
def __init__(self, rescale_factor: float = 1.0):
self.rescale_factor = rescale_factor
def __call__(self, image: "torch.Tensor"):
image = image * self.rescale_factor
return image
class NumpyToTensor:
"""
Convert a numpy array to a PyTorch tensor.
"""
def __call__(self, image: np.ndarray):
# Same as in PyTorch, we assume incoming numpy images are in HWC format
# c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()
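
The fusion trick in `FusedRescaleNormalize` relies on a simple identity: dividing the mean and std by the rescale factor folds the rescale into the normalize, since (x - mean/r) / (std/r) = (r*x - mean) / std. A quick sanity check (assuming torch and torchvision are installed; the numbers are illustrative):

```python
import torch
from torchvision.transforms import functional as F

x = torch.randint(0, 256, (3, 224, 224)).float()  # a fake uint8-range image
mean, std, r = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], 1 / 255

# Fused: normalize with mean/r and std/r, skipping the explicit rescale
fused = F.normalize(x, [m / r for m in mean], [s / r for s in std])
# Two-step: rescale to [0, 1], then normalize
two_step = F.normalize(x * r, mean, std)

assert torch.allclose(fused, two_step, atol=1e-5)
```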
@@ -25,9 +25,11 @@ from packaging import version
from .utils import (
ExplicitEnum,
is_jax_tensor,
is_numpy_array,
is_tf_tensor,
is_torch_available,
is_torch_tensor,
is_torchvision_available,
is_vision_available,
logging,
requires_backends,
@@ -52,6 +54,20 @@ if is_vision_available():
else:
PILImageResampling = PIL.Image
if is_torchvision_available():
from torchvision.transforms import InterpolationMode
    pil_torch_interpolation_mapping = {
        PILImageResampling.NEAREST: InterpolationMode.NEAREST,
        PILImageResampling.BOX: InterpolationMode.BOX,
        PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
        PILImageResampling.HAMMING: InterpolationMode.HAMMING,
        PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
        PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
    }
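    # This mapping lets the torchvision-backed fast processors honor the PIL
    # `resample` enum stored in existing image processor configs, e.g. (a
    # sketch, assuming torchvision is installed):
    #
    #     from torchvision.transforms import Resize
    #
    #     resize = Resize((224, 224), interpolation=pil_torch_interpolation_mapping[PILImageResampling.BILINEAR])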
if TYPE_CHECKING:
if is_torch_available():
import torch
@@ -90,14 +106,30 @@ def is_pil_image(img):
return is_vision_available() and isinstance(img, PIL.Image.Image)
class ImageType(ExplicitEnum):
PIL = "pillow"
TORCH = "torch"
NUMPY = "numpy"
TENSORFLOW = "tensorflow"
JAX = "jax"
def get_image_type(image):
if is_pil_image(image):
return ImageType.PIL
if is_torch_tensor(image):
return ImageType.TORCH
if is_numpy_array(image):
return ImageType.NUMPY
if is_tf_tensor(image):
return ImageType.TENSORFLOW
if is_jax_tensor(image):
return ImageType.JAX
raise ValueError(f"Unrecognised image type {type(image)}")
def is_valid_image(img):
-    return (
-        (is_vision_available() and isinstance(img, PIL.Image.Image))
-        or isinstance(img, np.ndarray)
-        or is_torch_tensor(img)
-        or is_tf_tensor(img)
-        or is_jax_tensor(img)
-    )
+    return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
def valid_images(imgs):
......
@@ -19,6 +19,7 @@ from ...utils import (
is_flax_available,
is_tf_available,
is_torch_available,
is_torchvision_available,
is_vision_available,
)
@@ -34,6 +35,15 @@ else:
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
_import_structure["image_processing_vit"] = ["ViTImageProcessor"]
try:
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_vit_fast"] = ["ViTImageProcessorFast"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -83,6 +93,14 @@ if TYPE_CHECKING:
from .feature_extraction_vit import ViTFeatureExtractor
from .image_processing_vit import ViTImageProcessor
try:
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_vit_fast import ViTImageProcessorFast
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
......
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for ViT."""
import functools
from typing import Dict, List, Optional, Union
from ...image_processing_base import BatchFeature
from ...image_processing_utils import get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict
from ...image_transforms import FusedRescaleNormalize, NumpyToTensor, Rescale
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
ImageType,
PILImageResampling,
get_image_type,
make_list_of_images,
pil_torch_interpolation_mapping,
)
from ...utils import TensorType, logging
from ...utils.import_utils import is_torch_available, is_torchvision_available
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
if is_torchvision_available():
from torchvision.transforms import Compose, Normalize, PILToTensor, Resize
class ViTImageProcessorFast(BaseImageProcessorFast):
r"""
    Constructs a fast ViT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
_transform_params = [
"do_resize",
"do_rescale",
"do_normalize",
"size",
"resample",
"rescale_factor",
"image_mean",
"image_std",
"image_type",
]
def __init__(
self,
do_resize: bool = True,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
self.do_resize = do_resize
self.do_rescale = do_rescale
self.do_normalize = do_normalize
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self._transform_settings = {}
def _build_transforms(
self,
do_resize: bool,
size: Dict[str, int],
resample: PILImageResampling,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
image_type: ImageType,
) -> "Compose":
"""
        Given the input settings, build the image transforms using `torchvision.transforms.Compose`.
"""
transforms = []
# All PIL and numpy values need to be converted to a torch tensor
# to keep cross compatibility with slow image processors
if image_type == ImageType.PIL:
transforms.append(PILToTensor())
elif image_type == ImageType.NUMPY:
transforms.append(NumpyToTensor())
if do_resize:
transforms.append(
Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
)
# We can combine rescale and normalize into a single operation for speed
if do_rescale and do_normalize:
transforms.append(FusedRescaleNormalize(image_mean, image_std, rescale_factor=rescale_factor))
elif do_rescale:
transforms.append(Rescale(rescale_factor=rescale_factor))
elif do_normalize:
transforms.append(Normalize(image_mean, image_std))
return Compose(transforms)
@functools.lru_cache(maxsize=1)
def _validate_input_arguments(
self,
return_tensors: Union[str, TensorType],
do_resize: bool,
size: Dict[str, int],
resample: PILImageResampling,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
data_format: Union[str, ChannelDimension],
image_type: ImageType,
):
if return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
if do_resize and None in (size, resample):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and None in (image_mean, image_std):
raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = "pt",
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
):
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
                Image to preprocess. Expects a single image or a batch of images with pixel values ranging from
                0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
resizing.
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Only `"pt"` is currently supported.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. The following formats are currently supported:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
resample = resample if resample is not None else self.resample
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
# Make hashable for cache
size = SizeDict(**size)
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
images = make_list_of_images(images)
image_type = get_image_type(images[0])
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")
self._validate_input_arguments(
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
data_format=data_format,
image_type=image_type,
)
transforms = self.get_transforms(
do_resize=do_resize,
do_rescale=do_rescale,
do_normalize=do_normalize,
size=size,
resample=resample,
rescale_factor=rescale_factor,
image_mean=image_mean,
image_std=image_std,
image_type=image_type,
)
transformed_images = [transforms(image) for image in images]
data = {"pixel_values": torch.vstack(transformed_images)}
return BatchFeature(data, tensor_type=return_tensors)
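
End to end, the fast path then looks like this (a minimal sketch, assuming torch, torchvision and Pillow are installed; the random image is purely illustrative):

```python
import numpy as np
from PIL import Image

from transformers import ViTImageProcessorFast

processor = ViTImageProcessorFast()
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

batch = processor.preprocess([image, image], return_tensors="pt")
print(batch["pixel_values"].shape)  # expected: torch.Size([2, 3, 224, 224])
```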
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class BaseImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class ViTImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
@@ -9,6 +9,13 @@ class ImageProcessingMixin(metaclass=DummyObject):
requires_backends(self, ["vision"])
class BaseImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class ImageFeatureExtractionMixin(metaclass=DummyObject):
_backends = ["vision"]
......
@@ -27,8 +27,10 @@ from transformers import (
AutoImageProcessor,
CLIPConfig,
CLIPImageProcessor,
ViTImageProcessor,
ViTImageProcessorFast,
)
-from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
+from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_torchvision, require_vision
sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
@@ -133,6 +135,23 @@ class AutoImageProcessorTest(unittest.TestCase):
):
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/config-no-model")
@require_vision
@require_torchvision
def test_use_fast_selection(self):
checkpoint = "hf-internal-testing/tiny-random-vit"
# Slow image processor is selected by default
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
self.assertIsInstance(image_processor, ViTImageProcessor)
# Fast image processor is selected when use_fast=True
image_processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)
self.assertIsInstance(image_processor, ViTImageProcessorFast)
# Slow image processor is selected when use_fast=False
image_processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=False)
self.assertIsInstance(image_processor, ViTImageProcessor)
def test_from_pretrained_dynamic_image_processor(self):
# If remote code is not set, we will time out when asking whether to load the model.
with self.assertRaises(ValueError):
......
@@ -121,6 +121,7 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = BeitImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = BeitImageProcessingTester(self)
@property
......
@@ -90,6 +90,7 @@ class BlipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = BlipImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = BlipImageProcessingTester(self)
@property
@@ -112,6 +113,7 @@ class BlipImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.Tes
image_processing_class = BlipImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = BlipImageProcessingTester(self, num_channels=4)
self.expected_encoded_image_num_channels = 3
......
@@ -136,6 +136,7 @@ class BridgeTowerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
image_processing_class = BridgeTowerImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = BridgeTowerImageProcessingTester(self)
@property
......
@@ -98,6 +98,7 @@ class ChineseCLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = ChineseCLIPImageProcessingTester(self, do_center_crop=True)
@property
@@ -135,6 +136,7 @@ class ChineseCLIPImageProcessingTestFourChannels(ImageProcessingTestMixin, unitt
image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=4, do_center_crop=True)
self.expected_encoded_image_num_channels = 3
......
@@ -94,6 +94,7 @@ class CLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = CLIPImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = CLIPImageProcessingTester(self)
@property
......