Unverified Commit a6b77598 authored by amyeroberts, committed by GitHub

Add Image Processors (#19796)



* Add CLIP image processor

* Crop size as dict too

* Update warning

* Actually use logger this time

* Normalize doesn't change dtype of input

* Add perceiver image processor

* Tidy up

* Add DPT image processor

* Add Vilt image processor

* Tidy up

* Add poolformer image processor

* Tidy up

* Add LayoutLM v2 and v3 image processors

* Tidy up

* Add Flava image processor

* Tidy up

* Add deit image processor

* Tidy up

* Add ConvNext image processor

* Tidy up

* Add levit image processor

* Add segformer image processor

* Add in post processing

* Fix up

* Add ImageGPT image processor

* Fixup

* Add mobilevit image processor

* Tidy up

* Add postprocessing

* Fixup

* Add VideoMAE image processor

* Tidy up

* Add ImageGPT image processor

* Fixup

* Add ViT image processor

* Tidy up

* Add beit image processor

* Add mobilevit image processor

* Tidy up

* Add postprocessing

* Fixup

* Fix up

* Fix flava and remove tree module

* Fix image classification pipeline failing tests

* Update feature extractor in trainer scripts

* Update pad_if_smaller to accept tuple and int size

* Update for image segmentation pipeline

* Update src/transformers/models/perceiver/image_processing_perceiver.py
Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>

* Update src/transformers/image_processing_utils.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/beit/image_processing_beit.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* PR comments - docstrings; remove accidentally added resize; var names

* Update docstrings

* Add exception if size is not in the right format

* Fix exception check

* Fix up

* Use shortest_edge in tuple in script
Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Segformer."""
import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL.Image
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class SegformerImageProcessor(BaseImageProcessor):
r"""
Constructs a Segformer image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
`preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_reduce_labels: bool = False,
**kwargs
) -> None:
if "reduce_labels" in kwargs:
warnings.warn(
"The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
"`do_reduce_labels` instead.",
FutureWarning,
)
do_reduce_labels = kwargs.pop("reduce_labels")
super().__init__(**kwargs)
size = size if size is not None else {"height": 512, "width": 512}
size = get_size_dict(size)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_reduce_labels = do_reduce_labels
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `size` along
any edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
    def reduce_label(self, label: ImageInput) -> np.ndarray:
        label = to_numpy_array(label)
        # Map the background class 0 to the ignore index 255 before shifting, so the
        # subtraction below cannot underflow for unsigned integer inputs.
        label[label == 0] = 255
        # Shift all remaining labels down by one.
        label = label - 1
        # Pixels that are now 254 (originally 0 or 255) go back to the ignore index 255.
        label[label == 254] = 255
        return label
def _preprocess(
self,
image: ImageInput,
do_reduce_labels: bool,
do_resize: bool,
do_rescale: bool,
do_normalize: bool,
size: Optional[Dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
rescale_factor: Optional[float] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
):
if do_reduce_labels:
image = self.reduce_label(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
return image
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
image = self._preprocess(
image=image,
do_reduce_labels=False,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_reduce_labels: bool = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
added_channel_dim = False
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
# reduce zero label if needed
segmentation_map = self._preprocess(
image=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize,
resample=PIL.Image.NEAREST,
size=size,
do_rescale=False,
do_normalize=False,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
segmentation_map = segmentation_map.astype(np.int64)
return segmentation_map
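    # Note (editorial): segmentation maps go through the same _preprocess pipeline as images, but
    # with nearest-neighbour resampling, no rescaling and no normalization, and the result is cast
    # to int64 class ids (with 0 remapped to the 255 ignore index when do_reduce_labels is set).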
def __call__(self, images, segmentation_maps=None, **kwargs):
"""
Preprocesses a batch of images and optionally segmentation maps.
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
passed in as positional arguments.
"""
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
segmentation_maps (`ImageInput`, *optional*):
Segmentation map to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after `resize` is applied.
resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to the [0, 1] range.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
resample = resample if resample is not None else self.resample
size = size if size is not None else self.size
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
if not is_batched(images):
images = [images]
segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
        if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
images = [
self._preprocess_image(
image=img,
do_resize=do_resize,
resample=resample,
size=size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
)
for img in images
]
data = {"pixel_values": images}
if segmentation_maps is not None:
segmentation_maps = [
self._preprocess_mask(
segmentation_map=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize,
resample=PIL.Image.NEAREST,
size=size,
)
for segmentation_map in segmentation_maps
]
data["labels"] = segmentation_maps
return BatchFeature(data=data, tensor_type=return_tensors)
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
PyTorch.
Args:
outputs ([`SegformerForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
"""
# TODO: add support for other frameworks
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
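A minimal usage sketch for the new Segformer processor (not part of this diff): it assumes `SegformerImageProcessor` is importable from the top-level `transformers` namespace, and the checkpoint name and image path below are illustrative only.

from PIL import Image

from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# Hypothetical checkpoint and image path, purely for illustration.
checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
image_processor = SegformerImageProcessor.from_pretrained(checkpoint)
model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)

image = Image.open("scene.png").convert("RGB")
inputs = image_processor(images=image, return_tensors="pt")  # pixel_values of shape (1, 3, 512, 512)

outputs = model(**inputs)
# Upsample the logits back to the original (height, width) and take the per-pixel argmax.
segmentation_map = image_processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]  # torch.Tensor of shape (height, width) holding class ids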
@@ -14,159 +14,11 @@
# limitations under the License.
"""Feature extractor class for VideoMAE."""

Added:
from ...utils import logging
from .image_processing_videomae import VideoMAEImageProcessor

logger = logging.get_logger(__name__)

VideoMAEFeatureExtractor = VideoMAEImageProcessor

Removed:
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
from ...utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, TensorType, logging

logger = logging.get_logger(__name__)

class VideoMAEFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a VideoMAE feature extractor. This feature extractor can be used to prepare videos for the model.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the shorter edge of the input to a certain `size`.
size (`int`, *optional*, defaults to 224):
Resize the shorter edge of the input to the given size. Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the input to a certain `size`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BILINEAR,
do_center_crop=True,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
def resize_video(self, video, size, resample="bilinear"):
return [self.resize(frame, size, resample, default_to_square=False) for frame in video]
def crop_video(self, video, size):
return [self.center_crop(frame, size) for frame in video]
def normalize_video(self, video, mean, std):
# video can be a list of PIL images, list of NumPy arrays or list of PyTorch tensors
# first: convert to list of NumPy arrays
video = [self.to_numpy_array(frame) for frame in video]
# second: stack to get (num_frames, num_channels, height, width)
video = np.stack(video, axis=0)
# third: normalize
if not isinstance(mean, np.ndarray):
mean = np.array(mean).astype(video.dtype)
if not isinstance(std, np.ndarray):
std = np.array(std).astype(video.dtype)
return (video - mean[None, :, None, None]) / std[None, :, None, None]
def __call__(
self, videos: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several video(s).
<Tip warning={true}>
NumPy arrays are converted to PIL images when resizing, so the most efficient is to pass PIL images.
</Tip>
Args:
            videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`,
`List[List[torch.Tensor]]`): The video or batch of videos to be prepared. Each video should be a list
of frames, which can be either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors,
each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of
channels.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, num_frames,
height, width).
"""
# Input type checking for clearer error
valid_videos = False
is_batched = False
# Check that videos have a valid type
if isinstance(videos, (list, tuple)):
if isinstance(videos[0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0]):
valid_videos = True
elif isinstance(videos[0], (list, tuple)) and (
isinstance(videos[0][0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0][0])
):
valid_videos = True
is_batched = True
if not valid_videos:
raise ValueError(
"Videos must of type `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]` (single"
" example), `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`, `List[List[torch.Tensor]]` (batch"
" of examples)."
)
if not is_batched:
videos = [videos]
# transformations (resizing + center cropping + normalization)
if self.do_resize and self.size is not None:
videos = [self.resize_video(video, size=self.size, resample=self.resample) for video in videos]
if self.do_center_crop and self.size is not None:
videos = [self.crop_video(video, size=self.size) for video in videos]
if self.do_normalize:
videos = [self.normalize_video(video, mean=self.image_mean, std=self.image_std) for video in videos]
# return as BatchFeature
data = {"pixel_values": videos}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for VideoMAE."""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
get_resize_output_image_size,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_valid_image,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def make_batched(videos) -> List[List[ImageInput]]:
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
return videos
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
return [videos]
elif is_valid_image(videos):
return [[videos]]
raise ValueError(f"Could not make batched video from {videos}")
class VideoMAEImageProcessor(BaseImageProcessor):
r"""
Constructs a VideoMAE image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
Size of the output image after resizing. The shortest edge of the image will be resized to
`size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overriden by
`size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
parameter in the `preprocess` method.
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size)
self.do_resize = do_resize
self.size = size
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its
shortest edge of length `s` while keeping the aspect ratio of the original image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
if "shortest_edge" in size:
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
elif "height" in size and "width" in size:
output_size = (size["height"], size["width"])
else:
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def center_crop(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `size` along any
edge, the image is padded with 0's and then center cropped.
Args:
image (`np.ndarray`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size)
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
) -> np.ndarray:
"""Preprocesses a single image."""
        if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample)
if do_center_crop:
image = self.center_crop(image, size=crop_size)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std)
image = to_channel_dimension_format(image, data_format)
return image
def preprocess(
self,
videos: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
    ) -> BatchFeature:
"""
        Preprocess a video or batch of videos.
        Args:
            videos (`ImageInput`):
                Video or batch of videos to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after applying resize.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the image after applying the center crop.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to the [0, 1] range.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the inferred channel dimension format of the input image.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size)
if not valid_images(videos):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
videos = make_batched(videos)
videos = [
[
self._preprocess_image(
image=img,
do_resize=do_resize,
size=size,
resample=resample,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
)
for img in video
]
for video in videos
]
data = {"pixel_values": videos}
return BatchFeature(data=data, tensor_type=return_tensors)
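A minimal usage sketch for the new VideoMAE processor (not part of this diff), assuming `VideoMAEImageProcessor` is importable from the top-level `transformers` namespace; the random frames below are illustrative only.

import numpy as np

from transformers import VideoMAEImageProcessor

image_processor = VideoMAEImageProcessor()  # default: resize shortest edge to 224, then 224x224 center crop

# A video is a list of frames; frames may be PIL images or (height, width, channels) arrays.
video = [np.random.randint(0, 256, (360, 640, 3), dtype=np.uint8) for _ in range(16)]

inputs = image_processor(video, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 16, 3, 224, 224])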
@@ -14,282 +14,11 @@
# limitations under the License.
"""Feature extractor class for ViLT."""

Added:
from ...utils import logging
from .image_processing_vilt import ViltImageProcessor

logger = logging.get_logger(__name__)

ViltFeatureExtractor = ViltImageProcessor

Removed:
from typing import List, Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ImageFeatureExtractionMixin,
    ImageInput,
    is_torch_tensor,
)
from ...utils import TensorType, is_torch_available, logging

if is_torch_available():
    import torch

logger = logging.get_logger(__name__)

class ViltFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a ViLT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input based on `size`.
size (`int`, *optional*, defaults to 384):
Resize the shorter side of the input to the given size. Should be an integer. The longer side will be
limited to under int((1333 / 800) * size) while preserving the aspect ratio. Only has an effect if
`do_resize` is set to `True`.
size_divisor (`int`, *optional*, defaults to 32):
The size by which to make sure both the height and width can be divided.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values", "pixel_mask"]
def __init__(
self,
do_resize=True,
size=384,
size_divisor=32,
resample=PILImageResampling.BICUBIC,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.size_divisor = size_divisor
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def _resize(self, image, shorter=800, longer=1333, size_divisor=32, resample=PILImageResampling.BICUBIC):
"""
Resizes the shorter edge of `image` to `shorter` and limits the longer edge to under `longer`, while preserving
the aspect ratio. Also makes sure that both the height and width can be divided by `size_divisor`.
Based on original implementation:
https://github.com/dandelin/ViLT/blob/3db8b5035464afee84d951bf6322e1b27f1d072d/vilt/transforms/utils.py#L5
Args:
image (`PIL.Image`):
The image to resize.
shorter (`int`, *optional*, defaults to `800`):
The size to which to resize the shorter side of the image.
longer (`int`, *optional*, defaults to `1333`):
The size by which to limit the longer side of the image, while preserving the aspect ratio.
size_divisor (`int`, *optional*, defaults to `32`):
The size by which both the height and the width must be divisible.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
An optional resampling filter.
"""
if not isinstance(image, Image.Image):
image = self.to_pil_image(image)
w, h = image.size
min_size = shorter
max_size = longer
scale = min_size / min(w, h)
if h < w:
newh, neww = min_size, scale * w
else:
newh, neww = scale * h, min_size
if max(newh, neww) > max_size:
scale = max_size / max(newh, neww)
newh = newh * scale
neww = neww * scale
newh, neww = int(newh + 0.5), int(neww + 0.5)
newh, neww = newh // size_divisor * size_divisor, neww // size_divisor * size_divisor
return self.resize(image, size=(neww, newh), resample=resample)
def _max_by_axis(self, the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def pad_and_create_pixel_mask(
self, pixel_values_list: List["torch.Tensor"], return_tensors: Optional[Union[str, TensorType]] = None
):
"""
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
Args:
pixel_values_list (`List[torch.Tensor]`):
List of images (pixel values) to be padded. Each image should be a tensor of shape (C, H, W).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
*"pixel_mask"* is in `self.model_input_names`).
"""
max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
c, h, w = max_size
padded_images = []
pixel_mask = []
for image in pixel_values_list:
# create padded image
padded_image = np.zeros((c, h, w), dtype=np.float32)
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
padded_images.append(padded_image)
# create pixel mask
mask = np.zeros((h, w), dtype=np.int64)
mask[: image.shape[1], : image.shape[2]] = True
pixel_mask.append(mask)
# return as BatchFeature
data = {"pixel_values": padded_images, "pixel_mask": pixel_mask}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def __call__(
self,
images: ImageInput,
pad_and_return_pixel_mask: Optional[bool] = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
If left to the default, will return a pixel mask that is:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
- **pixel_mask** -- Pixel mask to be fed to a model (when `return_pixel_mask=True` or if *"pixel_mask"* is
in `self.model_input_names`).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
longer = int((1333 / 800) * self.size)
images = [
self._resize(
image=image,
shorter=self.size,
longer=longer,
size_divisor=self.size_divisor,
resample=self.resample,
)
for image in images
]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
if pad_and_return_pixel_mask:
# pad images up to largest image in batch and create pixel_mask
max_size = self._max_by_axis([list(image.shape) for image in images])
c, h, w = max_size
padded_images = []
pixel_mask = []
for image in images:
# create padded image
padded_image = np.zeros((c, h, w), dtype=np.float32)
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
padded_images.append(padded_image)
# create pixel mask
mask = np.zeros((h, w), dtype=np.int64)
mask[: image.shape[1], : image.shape[2]] = True
pixel_mask.append(mask)
images = padded_images
# return as BatchFeature
data = {}
data["pixel_values"] = images
if pad_and_return_pixel_mask:
data["pixel_mask"] = pixel_mask
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Vilt."""
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def max_across_indices(values: Iterable[Any]) -> List[Any]:
"""
Return the maximum value across all indices of an iterable of values.
"""
return [max(values_i) for values_i in zip(*values)]
def pad(
image: np.ndarray,
output_size: Tuple[int, int],
input_channel_dimension: Optional[ChannelDimension] = None,
data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
"""
Pad the bottom and right of the image with zeros to the output size.
Args:
image (`np.ndarray`):
Image to pad.
output_size (`Tuple[int, int]`):
Output size of the image.
input_channel_dimension (`ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be inferred from the input image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
if input_channel_dimension is None:
input_channel_dimension = infer_channel_dimension_format(image)
output_height, output_width = output_size
input_height, input_width = get_image_size(image)
pad_bottom = output_height - input_height
pad_right = output_width - input_width
if input_channel_dimension == ChannelDimension.FIRST:
padded_image = np.pad(image, [(0, 0), (0, pad_bottom), (0, pad_right)], mode="constant", constant_values=0)
elif input_channel_dimension == ChannelDimension.LAST:
padded_image = np.pad(image, [(0, pad_bottom), (0, pad_right), (0, 0)], mode="constant", constant_values=0)
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
if data_format is not None:
padded_image = to_channel_dimension_format(padded_image, data_format)
return padded_image
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Args:
image (`np.ndarray`):
Image to make the pixel mask for.
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
def get_max_dimensions(images: List[np.ndarray]) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
input_channel_dimension = infer_channel_dimension_format(images[0])
if input_channel_dimension == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
return (max_height, max_width)
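# Illustrative example of the padding helpers above (shapes are hypothetical):
#   images = [img_a with shape (3, 320, 480), img_b with shape (3, 384, 416)]
#   get_max_dimensions(images)               -> (384, 480)
#   pad(img_a, (384, 480))                   -> zero-padded on the bottom/right to (3, 384, 480)
#   make_pixel_mask(img_a, (384, 480))       -> (384, 480) mask with 1s over the original 320x480 region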
def get_resize_output_image_size(
input_image: np.ndarray, shorter: int = 800, longer: int = 1333, size_divisor: int = 32
) -> Tuple[int, int]:
input_height, input_width = get_image_size(input_image)
min_size, max_size = shorter, longer
scale = min_size / min(input_height, input_width)
if input_height < input_width:
new_height = min_size
new_width = scale * input_width
else:
new_height = scale * input_height
new_width = min_size
if max(new_height, new_width) > max_size:
scale = max_size / max(new_height, new_width)
new_height = scale * new_height
new_width = scale * new_width
new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
new_height = new_height // size_divisor * size_divisor
new_width = new_width // size_divisor * size_divisor
return new_height, new_width
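# Worked example for get_resize_output_image_size (input size is hypothetical):
#   a 480x640 image with shorter=384, longer=639 and size_divisor=32:
#     scale = 384 / 480 = 0.8  ->  candidate size 384 x 512
#     max(384, 512) = 512 <= 639, so no second rescale is needed
#     rounding and flooring to multiples of 32 keeps (384, 512)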
class ViltImageProcessor(BaseImageProcessor):
r"""
Constructs a ViLT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
`int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
`do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
size_divisor (`int`, *optional*, defaults to 32):
The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
the `do_pad` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = True,
**kwargs
) -> None:
if "pad_and_return_pixel_mask" in kwargs:
do_pad = kwargs.pop("pad_and_return_pixel_mask")
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 384}
size = get_size_dict(size, default_to_square=False)
self.do_resize = do_resize
self.size = size
self.size_divisor = size_divisor
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_pad = do_pad
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image.
Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
longer side is larger than the max size `int(size["shortest_edge"] * 1333 / 800)`, the longer side is then
resized to the max size while preserving the aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Controls the size of the output image. Should be of the form `{"shortest_edge": int}`.
size_divisor (`int`, defaults to 32):
The image is resized to a size that is a multiple of this value.
resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
shorter = size["shortest_edge"]
longer = int(1333 / 800 * shorter)
output_size = get_resize_output_image_size(image, shorter=shorter, longer=longer, size_divisor=size_divisor)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
def rescale(
self,
image: np.ndarray,
scale: Union[int, float],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
):
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean.
std (`float` or `List[float]`):
Image standard deviation.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def pad(
self,
images: List[np.ndarray],
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
) -> BatchFeature:
"""
Pads a batch of images with zeros up to the largest height and width in the batch and optionally returns
their corresponding pixel mask.
Args:
images (`List[np.ndarray]`):
Batch of images to pad.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return the pixel mask.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
pad_size = get_max_dimensions(images)
padded_images = [pad(image=image, output_size=pad_size, data_format=data_format) for image in images]
data = {"pixel_values": padded_images}
if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
data["pixel_mask"] = masks
return BatchFeature(data=data, tensor_type=return_tensors)
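# Example (illustrative only): padding channels-first images of shapes (3, 10, 12) and (3, 8, 16)
# produces "pixel_values" entries of shape (3, 10, 16); each "pixel_mask" has shape (10, 16),
# with ones over the original image area and zeros over the padded region.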
def pad_and_create_pixel_mask(
self,
pixel_values_list: List[ImageInput],
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
) -> BatchFeature:
"""
Pads a batch of images with zeros up to the largest height and width in the batch and returns their
corresponding pixel mask.
Args:
pixel_values_list (`List[ImageInput]`):
Batch of images to pad.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
warnings.warn(
"This method is deprecated and will be removed in v4.26.0. Please use pad instead.", FutureWarning
)
# pad expects a list of np.ndarray, but the previous feature extractors expected torch tensors
images = [to_numpy_array(image) for image in pixel_values_list]
return self.pad(
images=images,
return_pixel_mask=True,
return_tensors=return_tensors,
data_format=data_format,
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
size_divisor: Optional[int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Controls the size of the image after `resize`. The shortest edge of the image is resized to
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
The image is resized to a size that is a multiple of this value.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values to the [0, 1] range.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image to the (max_height, max_width) in the batch. If `True`, a pixel mask is also
created and returned.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size_divisor = size_divisor if size_divisor is not None else self.size_divisor
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_pad = do_pad if do_pad is not None else self.do_pad
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [
self.resize(image=image, size=size, size_divisor=size_divisor, resample=resample) for image in images
]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
if do_pad:
encoded_outputs = self.pad(images, return_pixel_mask=True, return_tensors=return_tensors)
else:
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
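# A minimal usage sketch (not part of the diff), assuming the ViltImageProcessor defined above is
# exported as transformers.ViltImageProcessor; the import path is an assumption, not part of this file.
import numpy as np
from transformers import ViltImageProcessor

processor = ViltImageProcessor()  # defaults: shortest_edge=384, size_divisor=32, do_pad=True
images = [
    np.random.randint(0, 256, (360, 500, 3), dtype=np.uint8),  # (height, width, channels)
    np.random.randint(0, 256, (420, 640, 3), dtype=np.uint8),
]
outputs = processor.preprocess(images, return_tensors="np")
print(outputs["pixel_values"].shape)  # both images are resized, then padded to a common (height, width)
print(outputs["pixel_mask"].shape)    # 1 over real pixels, 0 over the padding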
...@@ -14,139 +14,12 @@
# limitations under the License.
"""Feature extractor class for ViT."""
New contents (the feature extractor is now an alias of the image processor):
from ...utils import logging
from .image_processing_vit import ViTImageProcessor
logger = logging.get_logger(__name__)
# Feature extractor for ViT is being replaced by image processor
ViTFeatureExtractor = ViTImageProcessor
Removed contents:
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a ViT feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=224,
resample=PILImageResampling.BILINEAR,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ViT."""
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import logging
logger = logging.get_logger(__name__)
class ViTImageProcessor(BaseImageProcessor):
r"""
Constructs a ViT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
self.do_resize = do_resize
self.do_rescale = do_rescale
self.do_normalize = do_normalize
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return resize(
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
)
def rescale(
self, image: np.ndarray, scale: float, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
) -> np.ndarray:
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`np.ndarray`):
Image to rescale.
scale (`float`):
The scaling factor to rescale pixel values by.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The rescaled image.
"""
return rescale(image, scale=scale, data_format=data_format, **kwargs)
def normalize(
self,
image: np.ndarray,
mean: Union[float, List[float]],
std: Union[float, List[float]],
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`np.ndarray`):
Image to normalize.
mean (`float` or `List[float]`):
Image mean to use for normalization.
std (`float` or `List[float]`):
Image standard deviation to use for normalization.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The normalized image.
"""
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
):
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
resizing.
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values to the [0, 1] range.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
resample = resample if resample is not None else self.resample
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size_dict = get_size_dict(size)
if not is_batched(images):
images = [images]
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_resize:
images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
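# A minimal usage sketch (not part of the diff), assuming the ViTImageProcessor defined above is
# exported as transformers.ViTImageProcessor; the import path is an assumption, not part of this file.
import numpy as np
from PIL import Image
from transformers import ViTImageProcessor

processor = ViTImageProcessor()  # defaults: size={"height": 224, "width": 224}, ImageNet mean/std
image = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))
inputs = processor.preprocess(image, return_tensors="np")
print(inputs["pixel_values"].shape)  # (1, 3, 224, 224): resized, rescaled, normalized, channels-first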
...@@ -44,14 +44,16 @@ class BeitFeatureExtractionTester(unittest.TestCase): ...@@ -44,14 +44,16 @@ class BeitFeatureExtractionTester(unittest.TestCase):
min_resolution=30, min_resolution=30,
max_resolution=400, max_resolution=400,
do_resize=True, do_resize=True,
size=20, size=None,
do_center_crop=True, do_center_crop=True,
crop_size=18, crop_size=None,
do_normalize=True, do_normalize=True,
image_mean=[0.5, 0.5, 0.5], image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5],
reduce_labels=False, do_reduce_labels=False,
): ):
size = size if size is not None else {"height": 20, "width": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.num_channels = num_channels self.num_channels = num_channels
...@@ -65,7 +67,7 @@ class BeitFeatureExtractionTester(unittest.TestCase): ...@@ -65,7 +67,7 @@ class BeitFeatureExtractionTester(unittest.TestCase):
self.do_normalize = do_normalize self.do_normalize = do_normalize
self.image_mean = image_mean self.image_mean = image_mean
self.image_std = image_std self.image_std = image_std
self.reduce_labels = reduce_labels self.do_reduce_labels = do_reduce_labels
def prepare_feat_extract_dict(self): def prepare_feat_extract_dict(self):
return { return {
...@@ -76,7 +78,7 @@ class BeitFeatureExtractionTester(unittest.TestCase): ...@@ -76,7 +78,7 @@ class BeitFeatureExtractionTester(unittest.TestCase):
"do_normalize": self.do_normalize, "do_normalize": self.do_normalize,
"image_mean": self.image_mean, "image_mean": self.image_mean,
"image_std": self.image_std, "image_std": self.image_std,
"reduce_labels": self.reduce_labels, "do_reduce_labels": self.do_reduce_labels,
} }
...@@ -141,8 +143,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -141,8 +143,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -153,8 +155,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -153,8 +155,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -173,8 +175,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -173,8 +175,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -185,8 +187,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -185,8 +187,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -205,8 +207,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -205,8 +207,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -217,8 +219,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -217,8 +219,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -239,16 +241,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -239,16 +241,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
self.assertEqual( self.assertEqual(
encoding["labels"].shape, encoding["labels"].shape,
( (
1, 1,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
self.assertEqual(encoding["labels"].dtype, torch.long) self.assertEqual(encoding["labels"].dtype, torch.long)
...@@ -262,16 +264,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -262,16 +264,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
self.assertEqual( self.assertEqual(
encoding["labels"].shape, encoding["labels"].shape,
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
self.assertEqual(encoding["labels"].dtype, torch.long) self.assertEqual(encoding["labels"].dtype, torch.long)
...@@ -287,16 +289,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -287,16 +289,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
self.assertEqual( self.assertEqual(
encoding["labels"].shape, encoding["labels"].shape,
( (
1, 1,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
self.assertEqual(encoding["labels"].dtype, torch.long) self.assertEqual(encoding["labels"].dtype, torch.long)
...@@ -312,16 +314,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -312,16 +314,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
2, 2,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
self.assertEqual( self.assertEqual(
encoding["labels"].shape, encoding["labels"].shape,
( (
2, 2,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
self.assertEqual(encoding["labels"].dtype, torch.long) self.assertEqual(encoding["labels"].dtype, torch.long)
......
...@@ -43,14 +43,16 @@ class CLIPFeatureExtractionTester(unittest.TestCase): ...@@ -43,14 +43,16 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
min_resolution=30, min_resolution=30,
max_resolution=400, max_resolution=400,
do_resize=True, do_resize=True,
size=20, size=None,
do_center_crop=True, do_center_crop=True,
crop_size=18, crop_size=None,
do_normalize=True, do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073], image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711], image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True, do_convert_rgb=True,
): ):
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.num_channels = num_channels self.num_channels = num_channels
...@@ -151,8 +153,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -151,8 +153,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -163,8 +165,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -163,8 +165,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -183,8 +185,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -183,8 +185,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -195,8 +197,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -195,8 +197,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -215,8 +217,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -215,8 +217,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -227,8 +229,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -227,8 +229,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -276,8 +278,8 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, un ...@@ -276,8 +278,8 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, un
( (
1, 1,
self.expected_encoded_image_num_channels, self.expected_encoded_image_num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -288,7 +290,7 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, un ...@@ -288,7 +290,7 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, un
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.expected_encoded_image_num_channels, self.expected_encoded_image_num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -43,12 +43,13 @@ class ConvNextFeatureExtractionTester(unittest.TestCase): ...@@ -43,12 +43,13 @@ class ConvNextFeatureExtractionTester(unittest.TestCase):
min_resolution=30, min_resolution=30,
max_resolution=400, max_resolution=400,
do_resize=True, do_resize=True,
size=20, size=None,
crop_pct=0.875, crop_pct=0.875,
do_normalize=True, do_normalize=True,
image_mean=[0.5, 0.5, 0.5], image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5],
): ):
size = size if size is not None else {"shortest_edge": 20}
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.num_channels = num_channels self.num_channels = num_channels
...@@ -113,8 +114,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T ...@@ -113,8 +114,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
), ),
) )
...@@ -125,8 +126,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T ...@@ -125,8 +126,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
), ),
) )
...@@ -145,8 +146,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T ...@@ -145,8 +146,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
), ),
) )
...@@ -157,8 +158,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T ...@@ -157,8 +158,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
), ),
) )
...@@ -177,8 +178,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T ...@@ -177,8 +178,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
), ),
) )
...@@ -189,7 +190,7 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T ...@@ -189,7 +190,7 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
self.feature_extract_tester.size, self.feature_extract_tester.size["shortest_edge"],
), ),
) )
...@@ -43,13 +43,16 @@ class DeiTFeatureExtractionTester(unittest.TestCase): ...@@ -43,13 +43,16 @@ class DeiTFeatureExtractionTester(unittest.TestCase):
min_resolution=30, min_resolution=30,
max_resolution=400, max_resolution=400,
do_resize=True, do_resize=True,
size=20, size=None,
do_center_crop=True, do_center_crop=True,
crop_size=18, crop_size=None,
do_normalize=True, do_normalize=True,
image_mean=[0.5, 0.5, 0.5], image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5],
): ):
size = size if size is not None else {"height": 20, "width": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.num_channels = num_channels self.num_channels = num_channels
...@@ -117,8 +120,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -117,8 +120,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -129,8 +132,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -129,8 +132,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -149,8 +152,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -149,8 +152,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -161,8 +164,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -161,8 +164,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -181,8 +184,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -181,8 +184,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -193,7 +196,7 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC ...@@ -193,7 +196,7 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["height"],
self.feature_extract_tester.crop_size, self.feature_extract_tester.crop_size["width"],
), ),
) )
...@@ -43,11 +43,12 @@ class DPTFeatureExtractionTester(unittest.TestCase): ...@@ -43,11 +43,12 @@ class DPTFeatureExtractionTester(unittest.TestCase):
min_resolution=30, min_resolution=30,
max_resolution=400, max_resolution=400,
do_resize=True, do_resize=True,
size=18, size=None,
do_normalize=True, do_normalize=True,
image_mean=[0.5, 0.5, 0.5], image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5],
): ):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.num_channels = num_channels self.num_channels = num_channels
...@@ -106,8 +107,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa ...@@ -106,8 +107,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["height"],
self.feature_extract_tester.size, self.feature_extract_tester.size["width"],
), ),
) )
...@@ -118,8 +119,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa ...@@ -118,8 +119,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["height"],
self.feature_extract_tester.size, self.feature_extract_tester.size["width"],
), ),
) )
...@@ -138,8 +139,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa ...@@ -138,8 +139,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["height"],
self.feature_extract_tester.size, self.feature_extract_tester.size["width"],
), ),
) )
...@@ -150,8 +151,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa ...@@ -150,8 +151,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["height"],
self.feature_extract_tester.size, self.feature_extract_tester.size["width"],
), ),
) )
...@@ -170,8 +171,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa ...@@ -170,8 +171,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
( (
1, 1,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["height"],
self.feature_extract_tester.size, self.feature_extract_tester.size["width"],
), ),
) )
...@@ -182,7 +183,7 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa ...@@ -182,7 +183,7 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
( (
self.feature_extract_tester.batch_size, self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels, self.feature_extract_tester.num_channels,
self.feature_extract_tester.size, self.feature_extract_tester.size["height"],
self.feature_extract_tester.size, self.feature_extract_tester.size["width"],
), ),
) )
@@ -28,11 +28,10 @@ if is_torch_available():
     import torch

 if is_vision_available():
-    from PIL import Image
+    import PIL

     from transformers import FlavaFeatureExtractor
-    from transformers.image_utils import PILImageResampling
-    from transformers.models.flava.feature_extraction_flava import (
+    from transformers.models.flava.image_processing_flava import (
         FLAVA_CODEBOOK_MEAN,
         FLAVA_CODEBOOK_STD,
         FLAVA_IMAGE_MEAN,
@@ -51,10 +50,12 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=224,
+        size=None,
         do_center_crop=True,
-        crop_size=224,
+        crop_size=None,
         resample=None,
+        do_rescale=True,
+        rescale_factor=1 / 255,
         do_normalize=True,
         image_mean=FLAVA_IMAGE_MEAN,
         image_std=FLAVA_IMAGE_STD,
@@ -65,23 +66,30 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         mask_group_min_aspect_ratio=0.3,
         mask_group_max_aspect_ratio=None,
         codebook_do_resize=True,
-        codebook_size=112,
+        codebook_size=None,
         codebook_resample=None,
         codebook_do_center_crop=True,
-        codebook_crop_size=112,
+        codebook_crop_size=None,
         codebook_do_map_pixels=True,
         codebook_do_normalize=True,
         codebook_image_mean=FLAVA_CODEBOOK_MEAN,
         codebook_image_std=FLAVA_CODEBOOK_STD,
     ):
+        size = size if size is not None else {"height": 224, "width": 224}
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
+        codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
         self.do_resize = do_resize
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
         self.min_resolution = min_resolution
         self.max_resolution = max_resolution
         self.size = size
-        self.resample = resample if resample is not None else PILImageResampling.BICUBIC
+        self.resample = resample if resample is not None else PIL.Image.Resampling.BICUBIC
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
@@ -97,7 +105,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         self.codebook_do_resize = codebook_do_resize
         self.codebook_size = codebook_size
-        self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.LANCZOS
+        self.codebook_resample = codebook_resample if codebook_resample is not None else PIL.Image.Resampling.LANCZOS
         self.codebook_do_center_crop = codebook_do_center_crop
         self.codebook_crop_size = codebook_crop_size
         self.codebook_do_map_pixels = codebook_do_map_pixels
@@ -113,6 +121,8 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
             "do_resize": self.do_resize,
             "size": self.size,
             "resample": self.resample,
+            "do_rescale": self.do_rescale,
+            "rescale_factor": self.rescale_factor,
             "do_center_crop": self.do_center_crop,
             "crop_size": self.crop_size,
             "input_size_patches": self.input_size_patches,
@@ -133,7 +143,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         }

     def get_expected_image_size(self):
-        return (self.size, self.size) if not isinstance(self.size, tuple) else self.size
+        return (self.size["height"], self.size["width"])

     def get_expected_mask_size(self):
         return (
@@ -143,10 +153,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
         )

     def get_expected_codebook_image_size(self):
-        if not isinstance(self.codebook_size, tuple):
-            return (self.codebook_size, self.codebook_size)
-        else:
-            return self.codebook_size
+        return (self.codebook_size["height"], self.codebook_size["width"])


 @require_torch
@@ -172,6 +179,8 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
         self.assertTrue(hasattr(feature_extractor, "resample"))
         self.assertTrue(hasattr(feature_extractor, "crop_size"))
         self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+        self.assertTrue(hasattr(feature_extractor, "do_rescale"))
+        self.assertTrue(hasattr(feature_extractor, "rescale_factor"))
         self.assertTrue(hasattr(feature_extractor, "masking_generator"))
         self.assertTrue(hasattr(feature_extractor, "codebook_do_resize"))
         self.assertTrue(hasattr(feature_extractor, "codebook_size"))
@@ -192,7 +201,7 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
         # create random PIL images
         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
         for image in image_inputs:
-            self.assertIsInstance(image, Image.Image)
+            self.assertIsInstance(image, PIL.Image.Image)

         # Test not batched input
         encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
@@ -324,7 +333,7 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
         # create random PIL images
         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
         for image in image_inputs:
-            self.assertIsInstance(image, Image.Image)
+            self.assertIsInstance(image, PIL.Image.Image)

         # Test not batched input
         encoded_images = feature_extractor(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")
...
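The recurring change in the Flava tester above, and in the testers that follow, is that integer `size`/`crop_size` defaults become `{"height": ..., "width": ...}` dictionaries and rescaling is exposed explicitly through `do_rescale`/`rescale_factor`. A minimal sketch of that pattern, using a made-up `ToyImageProcessingTester` rather than the real test classes:

class ToyImageProcessingTester:
    def __init__(self, size=None, crop_size=None, do_rescale=True, rescale_factor=1 / 255):
        # Sizes are dictionaries rather than bare ints, so height and width
        # (or a shortest edge) can be addressed by name.
        self.size = size if size is not None else {"height": 224, "width": 224}
        self.crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor

    def get_expected_image_size(self):
        # Shape checks can then index the dict instead of special-casing ints and tuples.
        return (self.size["height"], self.size["width"])


tester = ToyImageProcessingTester()
assert tester.get_expected_image_size() == (224, 224)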
@@ -32,7 +32,7 @@ if is_vision_available():
     from PIL import Image

     from transformers import FlavaFeatureExtractor, FlavaProcessor
-    from transformers.models.flava.feature_extraction_flava import (
+    from transformers.models.flava.image_processing_flava import (
         FLAVA_CODEBOOK_MEAN,
         FLAVA_CODEBOOK_STD,
         FLAVA_IMAGE_MEAN,
@@ -69,7 +69,6 @@ class FlavaProcessorTest(unittest.TestCase):
             "mask_group_max_aspect_ratio": None,
             "codebook_do_resize": True,
             "codebook_size": 112,
-            "codebook_resample": None,
             "codebook_do_center_crop": True,
             "codebook_crop_size": 112,
             "codebook_do_map_pixels": True,
...
@@ -47,9 +47,10 @@ class ImageGPTFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         do_normalize=True,
     ):
+        size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
...
@@ -43,9 +43,10 @@ class LayoutLMv2FeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         apply_ocr=True,
     ):
+        size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -97,8 +98,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -112,8 +113,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -132,8 +133,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -144,8 +145,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -164,8 +165,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -176,8 +177,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -210,12 +211,4 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
         encoding = feature_extractor(image, return_tensors="pt")
-        self.assertEqual(
-            encoding.pixel_values.shape,
-            (
-                1,
-                3,
-                224,
-                224,
-            ),
-        )
+        self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
@@ -43,9 +43,10 @@ class LayoutLMv3FeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         apply_ocr=True,
     ):
+        size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -97,8 +98,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -112,8 +113,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -132,8 +133,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -144,8 +145,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -164,8 +165,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
@@ -176,8 +177,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.size["height"],
+                self.feature_extract_tester.size["width"],
             ),
         )
...
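The image processors themselves normalize legacy integer or tuple sizes through the `get_size_dict` helper imported in the image processing modules of this PR, so configurations saved with `size=18` keep working. A rough sketch of the expected behaviour; treat the exact keyword name as an assumption:

from transformers.image_processing_utils import get_size_dict

# An int is expanded to a square height/width dict by default...
print(get_size_dict(18))                           # {"height": 18, "width": 18}
# ...while default_to_square=False (keyword name assumed here) yields a
# shortest-edge dict, as used by the LeViT/MobileViT/PoolFormer testers below.
print(get_size_dict(20, default_to_square=False))  # {"shortest_edge": 20}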
@@ -43,12 +43,15 @@ class LevitFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=18,
+        size=None,
         do_center_crop=True,
+        crop_size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"shortest_edge": 18}
+        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -58,6 +61,7 @@ class LevitFeatureExtractionTester(unittest.TestCase):
         self.do_resize = do_resize
         self.size = size
         self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
@@ -70,6 +74,7 @@ class LevitFeatureExtractionTester(unittest.TestCase):
             "do_resize": self.do_resize,
             "do_center_crop": self.do_center_crop,
             "size": self.size,
+            "crop_size": self.crop_size,
         }
@@ -113,8 +118,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -125,8 +130,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -145,8 +150,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -157,8 +162,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -177,8 +182,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -189,7 +194,7 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -43,11 +43,13 @@ class MobileViTFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize=True,
-        size=20,
+        size=None,
         do_center_crop=True,
-        crop_size=18,
+        crop_size=None,
         do_flip_channel_order=True,
     ):
+        size = size if size is not None else {"shortest_edge": 20}
+        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -109,8 +111,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -121,8 +123,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -141,8 +143,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -153,8 +155,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -173,8 +175,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -185,7 +187,7 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.crop_size,
-                self.feature_extract_tester.crop_size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -41,12 +41,15 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
         min_resolution=30,
         max_resolution=400,
         do_resize_and_center_crop=True,
-        size=30,
+        size=None,
         crop_pct=0.9,
+        crop_size=None,
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
     ):
+        size = size if size is not None else {"shortest_edge": 30}
+        crop_size = crop_size if crop_size is not None else {"height": 30, "width": 30}
         self.parent = parent
         self.batch_size = batch_size
         self.num_channels = num_channels
@@ -55,6 +58,7 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
         self.do_resize_and_center_crop = do_resize_and_center_crop
         self.size = size
         self.crop_pct = crop_pct
+        self.crop_size = crop_size
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
@@ -64,6 +68,7 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
             "size": self.size,
             "do_resize_and_center_crop": self.do_resize_and_center_crop,
             "crop_pct": self.crop_pct,
+            "crop_size": self.crop_size,
             "do_normalize": self.do_normalize,
             "image_mean": self.image_mean,
             "image_std": self.image_std,
@@ -111,8 +116,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -123,8 +128,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -143,8 +148,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -155,8 +160,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -175,8 +180,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 1,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
@@ -187,7 +192,7 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
             (
                 self.feature_extract_tester.batch_size,
                 self.feature_extract_tester.num_channels,
-                self.feature_extract_tester.size,
-                self.feature_extract_tester.size,
+                self.feature_extract_tester.crop_size["height"],
+                self.feature_extract_tester.crop_size["width"],
             ),
         )
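The shape assertions throughout these tests follow one recipe: build the feature extractor from the tester's keyword arguments and compare `pixel_values.shape` against the `crop_size` (or `size`) dictionary. A minimal, self-contained sketch of that recipe, reusing the LeViT tester defaults shown above; constructing the extractor directly from these two kwargs is an illustration, not the tests' exact setup:

import numpy as np
import PIL.Image

from transformers import LevitFeatureExtractor

# size/crop_size mirror the LeViT tester defaults above.
feature_extractor = LevitFeatureExtractor(
    size={"shortest_edge": 18}, crop_size={"height": 18, "width": 18}
)

# A single random RGB image stands in for prepare_image_inputs.
image = PIL.Image.fromarray(np.random.randint(0, 256, (30, 40, 3), dtype=np.uint8))
encoding = feature_extractor(image, return_tensors="pt")

# Height and width now come from the crop_size dict rather than a bare int.
assert encoding.pixel_values.shape == (1, 3, 18, 18)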