Unverified Commit ae1cffaf authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Add Donut image processor (#20425)



* Add Donut image processor

* Update src/transformers/image_transforms.py
Co-authored-by: default avatarAlara Dirik <8944735+alaradirik@users.noreply.github.com>

* Fix docstrings

* Full var names in docstring
Co-authored-by: default avatarAlara Dirik <8944735+alaradirik@users.noreply.github.com>
parent 28247e78
......@@ -194,6 +194,11 @@ We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers-
[[autodoc]] DonutSwinConfig
## DonutImageProcessor
[[autodoc]] DonutImageProcessor
- preprocess
## DonutFeatureExtractor
[[autodoc]] DonutFeatureExtractor
......
......@@ -727,7 +727,7 @@ else:
_import_structure["models.deit"].extend(["DeiTFeatureExtractor", "DeiTImageProcessor"])
_import_structure["models.detr"].append("DetrFeatureExtractor")
_import_structure["models.conditional_detr"].append("ConditionalDetrFeatureExtractor")
_import_structure["models.donut"].append("DonutFeatureExtractor")
_import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
_import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor", "FlavaImageProcessor"])
_import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
......@@ -3853,7 +3853,7 @@ if TYPE_CHECKING:
from .models.deformable_detr import DeformableDetrFeatureExtractor
from .models.deit import DeiTFeatureExtractor, DeiTImageProcessor
from .models.detr import DetrFeatureExtractor
from .models.donut import DonutFeatureExtractor
from .models.donut import DonutFeatureExtractor, DonutImageProcessor
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
from .models.flava import FlavaFeatureExtractor, FlavaImageProcessor, FlavaProcessor
from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
......
......@@ -223,11 +223,12 @@ def resize(
image,
size: Tuple[int, int],
resample=PILImageResampling.BILINEAR,
reducing_gap: Optional[int] = None,
data_format: Optional[ChannelDimension] = None,
return_numpy: bool = True,
) -> np.ndarray:
"""
Resizes `image` to (h, w) specified by `size` using the PIL library.
Resizes `image` to `(height, width)` specified by `size` using the PIL library.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
......@@ -236,8 +237,11 @@ def resize(
The size to use for resizing the image.
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
The filter to use for resampling.
reducing_gap (`int`, *optional*):
Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to
the fair resampling. See corresponding Pillow documentation for more details.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If `None`, will use the inferred format from the input.
The channel dimension format of the output image. If unset, will use the inferred format from the input.
return_numpy (`bool`, *optional*, defaults to `True`):
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
returned.
......@@ -260,7 +264,7 @@ def resize(
image = to_pil_image(image)
height, width = size
# PIL images are in the format (width, height)
resized_image = image.resize((width, height), resample=resample)
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
if return_numpy:
resized_image = np.array(resized_image)
......@@ -290,7 +294,7 @@ def normalize(
std (`float` or `Iterable[float]`):
The standard deviation to use for normalization.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If `None`, will use the inferred format from the input.
The channel dimension format of the output image. If unset, will use the inferred format from the input.
"""
if isinstance(image, PIL.Image.Image):
warnings.warn(
......
......@@ -44,6 +44,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
("data2vec-vision", "BeitImageProcessor"),
("deit", "DeiTImageProcessor"),
("dinat", "ViTImageProcessor"),
("donut-swin", "DonutImageProcessor"),
("dpt", "DPTImageProcessor"),
("flava", "FlavaImageProcessor"),
("glpn", "GLPNImageProcessor"),
......
......@@ -44,6 +44,7 @@ except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_donut"] = ["DonutFeatureExtractor"]
_import_structure["image_processing_donut"] = ["DonutImageProcessor"]
if TYPE_CHECKING:
......@@ -69,6 +70,7 @@ if TYPE_CHECKING:
pass
else:
from .feature_extraction_donut import DonutFeatureExtractor
from .image_processing_donut import DonutImageProcessor
else:
import sys
......
......@@ -14,197 +14,11 @@
# limitations under the License.
"""Feature extractor class for Donut."""
from typing import Optional, Tuple, Union
import numpy as np
from PIL import Image, ImageOps
from transformers.image_utils import PILImageResampling
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import TensorType, logging
from ...utils import logging
from .image_processing_donut import DonutImageProcessor
logger = logging.get_logger(__name__)
class DonutFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a Donut feature extractor.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the shorter edge of the input to the minimum value of a certain `size`.
size (`Tuple(int)`, *optional*, defaults to [1920, 2560]):
Resize the shorter edge of the input to the minimum value of the given size. Should be a tuple of (width,
height). Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
An optional resampling filter. This can be one of `PILImageResampling.NEAREST`, `PILImageResampling.BOX`,
`PILImageResampling.BILINEAR`, `PILImageResampling.HAMMING`, `PILImageResampling.BICUBIC` or
`PILImageResampling.LANCZOS`. Only has an effect if `do_resize` is set to `True`.
do_thumbnail (`bool`, *optional*, defaults to `True`):
Whether to thumbnail the input to the given `size`.
do_align_long_axis (`bool`, *optional*, defaults to `False`):
Whether to rotate the input if the height is greater than width.
do_pad (`bool`, *optional*, defaults to `True`):
Whether or not to pad the input to `size`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize=True,
size=[1920, 2560],
resample=PILImageResampling.BILINEAR,
do_thumbnail=True,
do_align_long_axis=False,
do_pad=True,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_thumbnail = do_thumbnail
self.do_align_long_axis = do_align_long_axis
self.do_pad = do_pad
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
def rotate_image(self, image, size):
if not isinstance(image, Image.Image):
image = self.to_pil_image(image)
if (size[1] > size[0] and image.width > image.height) or (size[1] < size[0] and image.width < image.height):
image = self.rotate(image, angle=-90, expand=True)
return image
def thumbnail(self, image, size):
if not isinstance(image, Image.Image):
image = self.to_pil_image(image)
image.thumbnail((size[0], size[1]))
return image
def pad(self, image: Image.Image, size: Tuple[int, int], random_padding: bool = False) -> Image.Image:
delta_width = size[0] - image.width
delta_height = size[1] - image.height
if random_padding:
pad_width = np.random.randint(low=0, high=delta_width + 1)
pad_height = np.random.randint(low=0, high=delta_height + 1)
else:
pad_width = delta_width // 2
pad_height = delta_height // 2
padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
return ImageOps.expand(image, padding)
def __call__(
self,
images: ImageInput,
return_tensors: Optional[Union[str, TensorType]] = None,
random_padding=False,
**kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
random_padding (`bool`, *optional*, defaults to `False`):
Whether to randomly pad the input to `size`.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (rotating + resizing + thumbnailing + padding + normalization)
if self.do_align_long_axis:
images = [self.rotate_image(image, self.size) for image in images]
if self.do_resize and self.size is not None:
images = [
self.resize(image=image, size=min(self.size), resample=self.resample, default_to_square=False)
for image in images
]
if self.do_thumbnail and self.size is not None:
images = [self.thumbnail(image=image, size=self.size) for image in images]
if self.do_pad and self.size is not None:
images = [self.pad(image=image, size=self.size, random_padding=random_padding) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
DonutFeatureExtractor = DonutImageProcessor
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Donut."""
from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
get_resize_output_image_size,
normalize,
pad,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
is_batched,
to_numpy_array,
valid_images,
)
from ...utils import TensorType, logging
from ...utils.import_utils import is_vision_available
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
class DonutImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Donut image processor.

    The processing pipeline applied by `preprocess` is, in order: align long axis (optional rotation), resize,
    thumbnail, pad, rescale, normalize and conversion to the requested channel dimension format.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image so that its shortest edge equals `min(size["height"], size["width"])`,
            keeping the input aspect ratio. Can be overridden by `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 2560, "width": 1920}`):
            Target size of the processed image. A list or tuple is interpreted in the `(width, height)` order used
            by the previous feature extractor. Can be overridden by `size` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess`
            method.
        do_thumbnail (`bool`, *optional*, defaults to `True`):
            Whether to shrink the image so that it fits inside `size` while keeping its aspect ratio. Can be
            overridden by `do_thumbnail` in the `preprocess` method.
        do_align_long_axis (`bool`, *optional*, defaults to `False`):
            Whether to rotate the image by 90 degrees so that its long axis matches the long axis of `size`. Can be
            overridden by `do_align_long_axis` in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to `size`. Can be overridden by `do_pad` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_thumbnail: bool = True,
        do_align_long_axis: bool = False,
        do_pad: bool = True,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)

        size = size if size is not None else {"height": 2560, "width": 1920}
        if isinstance(size, (tuple, list)):
            # The previous feature extractor size parameter was in (width, height) format
            size = size[::-1]
        size = get_size_dict(size)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_thumbnail = do_thumbnail
        self.do_align_long_axis = do_align_long_axis
        self.do_pad = do_pad
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def align_long_axis(
        self, image: np.ndarray, size: Dict[str, int], data_format: Optional[Union[str, ChannelDimension]] = None
    ) -> np.ndarray:
        """
        Align the long axis of the image to the longest axis of the specified size.

        Args:
            image (`np.ndarray`):
                The image to be aligned.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to align the long axis to.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the output image. If unset, the array is returned as is.

        Returns:
            `np.ndarray`: The aligned image.
        """
        input_height, input_width = get_image_size(image)
        output_height, output_width = size["height"], size["width"]

        # Rotate only when the image and the target disagree on which axis is the long one.
        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            # NOTE(review): `np.rot90` rotates the first two array axes by default, so this presumably expects a
            # channels-last (height, width, channels) array - confirm for channels-first inputs.
            image = np.rot90(image, 3)

        if data_format is not None:
            image = to_channel_dimension_format(image, data_format)

        return image

    def rotate_image(self, *args, **kwargs):
        """Deprecated alias of `align_long_axis`; logs a deprecation notice and forwards all arguments."""
        logger.info(
            "rotate_image is deprecated and will be removed in version 4.27. Please use align_long_axis instead."
        )
        return self.align_long_axis(*args, **kwargs)

    def pad_image(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        random_padding: bool = False,
        data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Pad the image to the specified size.

        Args:
            image (`np.ndarray`):
                The image to be padded.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to pad the image to.
            random_padding (`bool`, *optional*, defaults to `False`):
                Whether to use random padding or not. If `False`, the padding is split evenly between both sides
                (the extra pixel, if any, goes to the bottom/right).
            data_format (`str` or `ChannelDimension`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.

        Returns:
            `np.ndarray`: The padded image.
        """
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image)

        delta_width = output_width - input_width
        delta_height = output_height - input_height

        if random_padding:
            # `high` is exclusive, hence the `+ 1` so that all of the padding may land on one side.
            pad_top = np.random.randint(low=0, high=delta_height + 1)
            pad_left = np.random.randint(low=0, high=delta_width + 1)
        else:
            pad_top = delta_height // 2
            pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(image, padding, data_format=data_format)

    def pad(self, *args, **kwargs):
        """Deprecated alias of `pad_image`; logs a deprecation notice and forwards all arguments."""
        logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.")
        return self.pad_image(*args, **kwargs)

    def thumbnail(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Shrink the image so that it fits inside the specified size, in the manner of PIL's `Image.thumbnail`: the
        aspect ratio is preserved and the image is never enlarged.

        Args:
            image (`np.ndarray`):
                The image to be resized.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` of the box the image must fit in.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                The resampling filter to use.
            data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.

        Returns:
            `np.ndarray`: The thumbnailed image.
        """
        input_height, input_width = get_image_size(image)
        output_height, output_width = size["height"], size["width"]

        # Already fits inside the box: a thumbnail never upscales.
        if input_height <= output_height and input_width <= output_width:
            if data_format is not None:
                image = to_channel_dimension_format(image, data_format)
            return image

        # Pick the limiting dimension using integer arithmetic (no float rounding surprises):
        # output_height / input_height <= output_width / input_width  <=>  height is the limiting edge.
        if output_height * input_width <= output_width * input_height:
            new_height = output_height
            new_width = max(1, input_width * output_height // input_height)
        else:
            new_width = output_width
            new_height = max(1, input_height * output_width // input_width)

        return resize(
            image,
            size=(new_height, new_width),
            resample=resample,
            reducing_gap=2.0,
            data_format=data_format,
            **kwargs,
        )

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Resize an image. The shortest edge of the image is resized to `min(size["height"], size["width"])`, with the
        longest edge resized to keep the input aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size `{"height": h, "width": w}`; only the smaller of the two values is used.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.

        Returns:
            `np.ndarray`: The resized image.
        """
        size = get_size_dict(size)
        shortest_edge = min(size["height"], size["width"])
        output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ):
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return rescale(image, scale=scale, data_format=data_format, **kwargs)

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - mean) / std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        **kwargs
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to min(size["height"],
                size["width"]) with the longest edge resized to keep the input aspect ratio. A list or tuple is
                interpreted in the previous feature extractor's `(width, height)` order.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
                Whether to shrink the image to fit inside `size` using the thumbnail method.
            do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
                Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image to `size`.
            random_padding (`bool`, *optional*, defaults to `False`):
                Whether to use random padding when padding the image. If `True`, the padding needed to reach `size`
                is split randomly between the two sides of each axis; otherwise it is split evenly.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image pixel values.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with a `pixel_values` field holding the processed images.

        Raises:
            ValueError: If the images have an unsupported type, or if an enabled step is missing its parameters.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        if isinstance(size, (tuple, list)):
            # Previous feature extractor had size in (width, height) format
            size = size[::-1]
        size = get_size_dict(size)
        resample = resample if resample is not None else self.resample
        do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail
        do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis
        do_pad = do_pad if do_pad is not None else self.do_pad
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        if not is_batched(images):
            images = [images]

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_pad and size is None:
            raise ValueError("Size must be specified if do_pad is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_align_long_axis:
            # `size` is a required argument of `align_long_axis`.
            images = [self.align_long_axis(image, size=size) for image in images]

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_thumbnail:
            images = [self.thumbnail(image=image, size=size) for image in images]

        if do_pad:
            # Call `pad_image` directly rather than going through the deprecated `pad` alias.
            images = [self.pad_image(image=image, size=size, random_padding=random_padding) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
......@@ -37,12 +37,27 @@ class DonutProcessor(ProcessorMixin):
tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]):
An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input.
"""
feature_extractor_class = "AutoFeatureExtractor"
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
    """
    Build the processor from an image processor and a tokenizer.

    Args:
        image_processor: The image processor to wrap. Required unless the deprecated `feature_extractor`
            keyword argument is given instead.
        tokenizer: The tokenizer to wrap. Required.
        **kwargs: Only the deprecated `feature_extractor` keyword is read; it is used as the image processor
            when `image_processor` is not provided.

    Raises:
        ValueError: If no image processor (or deprecated feature extractor) or no tokenizer is given.
    """
    # Bind the name up front so the fallback below cannot hit an unbound local when the deprecated
    # keyword is absent.
    feature_extractor = None
    if "feature_extractor" in kwargs:
        warnings.warn(
            "The `feature_extractor` argument is deprecated and will be removed in v4.27, use `image_processor`"
            " instead.",
            FutureWarning,
        )
        feature_extractor = kwargs.pop("feature_extractor")

    image_processor = image_processor if image_processor is not None else feature_extractor
    if image_processor is None:
        raise ValueError("You need to specify an `image_processor`.")
    if tokenizer is None:
        raise ValueError("You need to specify a `tokenizer`.")

    super().__init__(image_processor, tokenizer)
    # The image processor is the active processor until the target context manager switches to the tokenizer.
    self.current_processor = self.image_processor
    self._in_target_context_manager = False
def __call__(self, *args, **kwargs):
......@@ -66,7 +81,7 @@ class DonutProcessor(ProcessorMixin):
raise ValueError("You need to specify either an `images` or `text` input to process.")
if images is not None:
inputs = self.feature_extractor(images, *args, **kwargs)
inputs = self.image_processor(images, *args, **kwargs)
if text is not None:
encodings = self.tokenizer(text, **kwargs)
......@@ -105,7 +120,7 @@ class DonutProcessor(ProcessorMixin):
self._in_target_context_manager = True
self.current_processor = self.tokenizer
yield
self.current_processor = self.feature_extractor
self.current_processor = self.image_processor
self._in_target_context_manager = False
def token2json(self, tokens, is_inner_value=False, added_vocab=None):
......@@ -157,3 +172,12 @@ class DonutProcessor(ProcessorMixin):
return [output] if is_inner_value else output
else:
return [] if is_inner_value else {"text_sequence": tokens}
@property
def feature_extractor_class(self):
    """Deprecated accessor: emits a `FutureWarning` and returns `self.image_processor_class`."""
    deprecation_message = (
        "`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`"
        " instead."
    )
    warnings.warn(deprecation_message, FutureWarning)
    return self.image_processor_class
......@@ -113,6 +113,13 @@ class DonutFeatureExtractor(metaclass=DummyObject):
requires_backends(self, ["vision"])
# Placeholder class carrying the `DonutImageProcessor` name; it declares "vision" as its required backend.
# NOTE(review): presumably `requires_backends` raises an informative error when that backend is not
# installed - confirm against its definition.
class DonutImageProcessor(metaclass=DummyObject):
    _backends = ["vision"]

    def __init__(self, *args, **kwargs):
        # Checked on every instantiation, whatever arguments are passed.
        requires_backends(self, ["vision"])
class DPTFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
......
......@@ -43,7 +43,7 @@ class DonutFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
size=[20, 18],
size=None,
do_thumbnail=True,
do_align_axis=False,
do_pad=True,
......@@ -58,7 +58,7 @@ class DonutFeatureExtractionTester(unittest.TestCase):
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.size = size if size is not None else {"height": 18, "width": 20}
self.do_thumbnail = do_thumbnail
self.do_align_axis = do_align_axis
self.do_pad = do_pad
......@@ -121,8 +121,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size[1],
self.feature_extract_tester.size[0],
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
......@@ -133,8 +133,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size[1],
self.feature_extract_tester.size[0],
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
......@@ -153,8 +153,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size[1],
self.feature_extract_tester.size[0],
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
......@@ -165,8 +165,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size[1],
self.feature_extract_tester.size[0],
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
......@@ -185,8 +185,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size[1],
self.feature_extract_tester.size[0],
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
......@@ -197,7 +197,7 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size[1],
self.feature_extract_tester.size[0],
self.feature_extract_tester.size["height"],
self.feature_extract_tester.size["width"],
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment