Unverified commit 6bca43bb, authored by amyeroberts, committed by GitHub

Input data format (#25464)

* Add copied from statements for image processors

* Move out rescale and normalize to base image processor

* Remove rescale and normalize from vit (post rebase)

* Update docstrings and tidy up

* PR comments

* Add input_data_format as preprocess argument

* Resolve tests and tidy up

* Remove num_channels argument

* Update doc strings -> default ints not in code formatting
parent a6609caf
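For context, a minimal sketch of how the new `input_data_format` argument introduced by this PR is meant to be used from user code (the processor class and shapes here are illustrative, not part of the diff):

```python
import numpy as np
from transformers import CLIPImageProcessor
from transformers.image_utils import ChannelDimension

processor = CLIPImageProcessor()

# A channels-first array: for ambiguous shapes, inference could guess wrong,
# so the caller can now state the layout explicitly.
image = np.random.rand(3, 224, 224).astype(np.float32)

inputs = processor(images=image, input_data_format=ChannelDimension.FIRST, return_tensors="np")
print(inputs["pixel_values"].shape)  # (1, 3, 224, 224)
```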
src/transformers/image_processing_utils.py
@@ -521,7 +521,12 @@ class BaseImageProcessor(ImageProcessingMixin):
         raise NotImplementedError("Each image processor must implement its own preprocess method")

     def rescale(
-        self, image: np.ndarray, scale: float, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
+        self,
+        image: np.ndarray,
+        scale: float,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
     ) -> np.ndarray:
         """
         Rescale an image by a scale factor. image = image * scale.
@@ -536,11 +541,16 @@ class BaseImageProcessor(ImageProcessingMixin):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

         Returns:
             `np.ndarray`: The rescaled image.
         """
-        return rescale(image, scale=scale, data_format=data_format, **kwargs)
+        return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)

     def normalize(
         self,
@@ -548,6 +558,7 @@ class BaseImageProcessor(ImageProcessingMixin):
         mean: Union[float, Iterable[float]],
         std: Union[float, Iterable[float]],
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -565,17 +576,25 @@ class BaseImageProcessor(ImageProcessingMixin):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

         Returns:
             `np.ndarray`: The normalized image.
         """
-        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+        return normalize(
+            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
+        )

     def center_crop(
         self,
         image: np.ndarray,
         size: Dict[str, int],
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -588,12 +607,26 @@ class BaseImageProcessor(ImageProcessingMixin):
             size (`Dict[str, int]`):
                 Size of the output image.
             data_format (`str` or `ChannelDimension`, *optional*):
-                The channel dimension format of the image. If not provided, it will be the same as the input image.
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
         """
         size = get_size_dict(size)
         if "height" not in size or "width" not in size:
             raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
-        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+        return center_crop(
+            image,
+            size=(size["height"], size["width"]),
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )


 VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"}, {"longest_edge"})
...
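The `input_data_format` escape hatch matters because, absent a hint, the channel layout has to be guessed from the array shape alone. A simplified sketch of the kind of heuristic `infer_channel_dimension_format` applies (illustrative, not the library's exact implementation):

```python
import numpy as np

def infer_channel_dim(image: np.ndarray, num_channels=(1, 3)) -> str:
    # Guess whether the channel axis is first or last by checking which
    # end of the shape looks like a channel count.
    if image.shape[0] in num_channels:
        return "channels_first"
    if image.shape[-1] in num_channels:
        return "channels_last"
    raise ValueError(f"Unable to infer channel dimension from shape {image.shape}")

print(infer_channel_dim(np.zeros((3, 32, 64))))  # channels_first
print(infer_channel_dim(np.zeros((32, 64, 3))))  # channels_last
# A (3, 3, 3) image is ambiguous and the heuristic just picks channels_first;
# this is exactly the case where passing input_data_format is necessary.
```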
src/transformers/models/beit/image_processing_beit.py
@@ -27,6 +27,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -145,6 +146,7 @@ class BeitImageProcessor(BaseImageProcessor):
         size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -159,12 +161,19 @@ class BeitImageProcessor(BaseImageProcessor):
                 Resampling filter to use when resizing the image.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
         """
         size = get_size_dict(size, default_to_square=True, param_name="size")
         if "height" not in size or "width" not in size:
             raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
         return resize(
-            image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+            image,
+            size=(size["height"], size["width"]),
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
         )

     def reduce_label(self, label: ImageInput) -> np.ndarray:
@@ -189,21 +198,22 @@ class BeitImageProcessor(BaseImageProcessor):
         do_normalize: bool = None,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ):
         if do_reduce_labels:
             image = self.reduce_label(image)

         if do_resize:
-            image = self.resize(image=image, size=size, resample=resample)
+            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

         if do_center_crop:
-            image = self.center_crop(image=image, size=crop_size)
+            image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

         if do_rescale:
-            image = self.rescale(image=image, scale=rescale_factor)
+            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

         if do_normalize:
-            image = self.normalize(image=image, mean=image_mean, std=image_std)
+            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

         return image

@@ -221,10 +231,13 @@ class BeitImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> np.ndarray:
         """Preprocesses a single image."""
         # All transformations expect numpy arrays.
         image = to_numpy_array(image)
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(image)
         image = self._preprocess(
             image,
             do_reduce_labels=False,
@@ -238,9 +251,10 @@ class BeitImageProcessor(BaseImageProcessor):
             do_normalize=do_normalize,
             image_mean=image_mean,
             image_std=image_std,
+            input_data_format=input_data_format,
         )
         if data_format is not None:
-            image = to_channel_dimension_format(image, data_format)
+            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
         return image

     def _preprocess_segmentation_map(
@@ -252,6 +266,7 @@ class BeitImageProcessor(BaseImageProcessor):
         do_center_crop: bool = None,
         crop_size: Dict[str, int] = None,
         do_reduce_labels: bool = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ):
         """Preprocesses a single segmentation map."""
         # All transformations expect numpy arrays.
@@ -260,8 +275,11 @@ class BeitImageProcessor(BaseImageProcessor):
         if segmentation_map.ndim == 2:
             segmentation_map = segmentation_map[None, ...]
             added_dimension = True
+            input_data_format = ChannelDimension.FIRST
         else:
             added_dimension = False
+            if input_data_format is None:
+                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
         segmentation_map = self._preprocess(
             image=segmentation_map,
             do_reduce_labels=do_reduce_labels,
@@ -272,6 +290,7 @@ class BeitImageProcessor(BaseImageProcessor):
             crop_size=crop_size,
             do_normalize=False,
             do_rescale=False,
+            input_data_format=ChannelDimension.FIRST,
         )
         # Remove extra axis if added
         if added_dimension:
@@ -301,6 +320,7 @@ class BeitImageProcessor(BaseImageProcessor):
         do_reduce_labels: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> PIL.Image.Image:
         """
@@ -344,8 +364,15 @@ class BeitImageProcessor(BaseImageProcessor):
                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                 The channel dimension format for the output image. Can be one of:
-                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
@@ -403,6 +430,7 @@ class BeitImageProcessor(BaseImageProcessor):
                 image_mean=image_mean,
                 image_std=image_std,
                 data_format=data_format,
+                input_data_format=input_data_format,
             )
             for img in images
         ]
...
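One detail in the BEiT changes above is worth spelling out: a 2D segmentation map gets a temporary channel axis, after which its layout is channels-first by construction, so `input_data_format` can be hard-coded rather than inferred. A numpy sketch of that shape handling (assumptions: plain arrays, transforms elided):

```python
import numpy as np

segmentation_map = np.random.randint(0, 150, size=(64, 64))

added_dimension = segmentation_map.ndim == 2
if added_dimension:
    # (H, W) -> (1, H, W): now unambiguously channels-first.
    segmentation_map = segmentation_map[None, ...]

# ... transforms run here with input_data_format=ChannelDimension.FIRST ...

if added_dimension:
    segmentation_map = np.squeeze(segmentation_map, axis=0)  # back to (H, W)
print(segmentation_map.shape)  # (64, 64)
```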
src/transformers/models/bit/image_processing_bit.py
@@ -31,6 +31,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -125,6 +126,7 @@ class BitImageProcessor(BaseImageProcessor):
         size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -140,12 +142,23 @@ class BitImageProcessor(BaseImageProcessor):
                 Resampling filter to use when resizing the image.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
         """
         size = get_size_dict(size, default_to_square=False)
         if "shortest_edge" not in size:
             raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
-        output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
-        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        output_size = get_resize_output_image_size(
+            image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
+        )
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )

     def preprocess(
         self,
@@ -163,6 +176,7 @@ class BitImageProcessor(BaseImageProcessor):
         do_convert_rgb: bool = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> PIL.Image.Image:
         """
@@ -205,9 +219,15 @@ class BitImageProcessor(BaseImageProcessor):
                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                 The channel dimension format for the output image. Can be one of:
-                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-                - Unset: defaults to the channel dimension format of the input image.
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
@@ -250,19 +270,36 @@ class BitImageProcessor(BaseImageProcessor):
         # All transformations expect numpy arrays.
         images = [to_numpy_array(image) for image in images]

+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
         if do_resize:
-            images = [self.resize(image=image, size=size, resample=resample) for image in images]
+            images = [
+                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_center_crop:
-            images = [self.center_crop(image=image, size=crop_size) for image in images]
+            images = [
+                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
+            ]

         if do_rescale:
-            images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_normalize:
-            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]

-        images = [to_channel_dimension_format(image, data_format) for image in images]
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]

         data = {"pixel_values": images}
         return BatchFeature(data=data, tensor_type=return_tensors)
src/transformers/models/blip/image_processing_blip.py
@@ -26,6 +26,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -111,6 +112,7 @@ class BlipImageProcessor(BaseImageProcessor):
         size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -128,6 +130,13 @@ class BlipImageProcessor(BaseImageProcessor):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

         Returns:
             `np.ndarray`: The resized image.
@@ -136,7 +145,14 @@ class BlipImageProcessor(BaseImageProcessor):
         if "height" not in size or "width" not in size:
             raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
         output_size = (size["height"], size["width"])
-        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )

     def preprocess(
         self,
@@ -152,6 +168,7 @@ class BlipImageProcessor(BaseImageProcessor):
         return_tensors: Optional[Union[str, TensorType]] = None,
         do_convert_rgb: bool = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> PIL.Image.Image:
         """
@@ -190,8 +207,15 @@ class BlipImageProcessor(BaseImageProcessor):
                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                 The channel dimension format for the output image. Can be one of:
-                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         resample = resample if resample is not None else self.resample
@@ -229,16 +253,31 @@ class BlipImageProcessor(BaseImageProcessor):
         # All transformations expect numpy arrays.
         images = [to_numpy_array(image) for image in images]

+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
         if do_resize:
-            images = [self.resize(image=image, size=size, resample=resample) for image in images]
+            images = [
+                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_rescale:
-            images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_normalize:
-            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]

-        images = [to_channel_dimension_format(image, data_format) for image in images]
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]

         encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
...
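Threading `input_data_format` into the final `to_channel_dimension_format` call (via `input_channel_dim`) spares a second round of shape inference on the already-transformed image. The conversion itself is just a transpose; an equivalent numpy sketch (hypothetical helper, not the library function):

```python
import numpy as np

def to_channels_first(image: np.ndarray, input_is_channels_last: bool) -> np.ndarray:
    # Transpose (H, W, C) -> (C, H, W) only when the input is channels-last;
    # a channels-first input passes through untouched.
    return image.transpose((2, 0, 1)) if input_is_channels_last else image

image = np.zeros((224, 224, 3))
print(to_channels_first(image, input_is_channels_last=True).shape)  # (3, 224, 224)
```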
src/transformers/models/bridgetower/image_processing_bridgetower.py
@@ -50,7 +50,9 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:

 # Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
-def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
+def make_pixel_mask(
+    image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
     """
     Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

@@ -60,33 +62,40 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
         output_size (`Tuple[int, int]`):
             Output size of the mask.
     """
-    input_height, input_width = get_image_size(image)
+    input_height, input_width = get_image_size(image, channel_dim=input_data_format)
     mask = np.zeros(output_size, dtype=np.int64)
     mask[:input_height, :input_width] = 1
     return mask


 # Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
-def get_max_height_width(images: List[np.ndarray]) -> List[int]:
+def get_max_height_width(
+    images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> List[int]:
     """
     Get the maximum height and width across all images in a batch.
     """
-    input_channel_dimension = infer_channel_dimension_format(images[0])
+    if input_data_format is None:
+        input_data_format = infer_channel_dimension_format(images[0])

-    if input_channel_dimension == ChannelDimension.FIRST:
+    if input_data_format == ChannelDimension.FIRST:
         _, max_height, max_width = max_across_indices([img.shape for img in images])
-    elif input_channel_dimension == ChannelDimension.LAST:
+    elif input_data_format == ChannelDimension.LAST:
         max_height, max_width, _ = max_across_indices([img.shape for img in images])
     else:
-        raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
+        raise ValueError(f"Invalid channel dimension format: {input_data_format}")
     return (max_height, max_width)


 # Copied from transformers.models.vilt.image_processing_vilt.get_resize_output_image_size
 def get_resize_output_image_size(
-    input_image: np.ndarray, shorter: int = 800, longer: int = 1333, size_divisor: int = 32
+    input_image: np.ndarray,
+    shorter: int = 800,
+    longer: int = 1333,
+    size_divisor: int = 32,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
 ) -> Tuple[int, int]:
-    input_height, input_width = get_image_size(input_image)
+    input_height, input_width = get_image_size(input_image, input_data_format)
     min_size, max_size = shorter, longer

     scale = min_size / min(input_height, input_width)
@@ -122,7 +131,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
             Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
             `int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
             `do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
-        size_divisor (`int`, *optional*, defaults to `32`):
+        size_divisor (`int`, *optional*, defaults to 32):
             The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
             is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
         resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
@@ -197,6 +206,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
         size_divisor: int = 32,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -217,20 +227,32 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
                 Resampling filter to use when resizing the image.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
         """
         size = get_size_dict(size, default_to_square=False)
         if "shortest_edge" not in size:
             raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
         shorter = size["shortest_edge"]
         longer = int(1333 / 800 * shorter)
-        output_size = get_resize_output_image_size(image, shorter=shorter, longer=longer, size_divisor=size_divisor)
-        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        output_size = get_resize_output_image_size(
+            image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format
+        )
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )

     def center_crop(
         self,
         image: np.ndarray,
         size: Dict[str, int],
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -244,9 +266,18 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
                 Size of the output image in the form `{"height": h, "width": w}`.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred from the input
+                image.
         """
         output_size = size["shortest_edge"]
-        return center_crop(image, size=(output_size, output_size), data_format=data_format, **kwargs)
+        return center_crop(
+            image,
+            size=(output_size, output_size),
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )

     # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
     def _pad_image(
@@ -255,18 +286,24 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
         output_size: Tuple[int, int],
         constant_values: Union[float, Iterable[float]] = 0,
         data_format: Optional[ChannelDimension] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> np.ndarray:
         """
         Pad an image with zeros to the given size.
         """
-        input_height, input_width = get_image_size(image)
+        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
         output_height, output_width = output_size

         pad_bottom = output_height - input_height
         pad_right = output_width - input_width
         padding = ((0, pad_bottom), (0, pad_right))
         padded_image = pad(
-            image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
+            image,
+            padding,
+            mode=PaddingMode.CONSTANT,
+            constant_values=constant_values,
+            data_format=data_format,
+            input_data_format=input_data_format,
         )
         return padded_image

@@ -278,6 +315,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
         return_pixel_mask: bool = True,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Optional[ChannelDimension] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> BatchFeature:
         """
         Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
@@ -299,17 +337,28 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        pad_size = get_max_height_width(images)
+        pad_size = get_max_height_width(images, input_data_format=input_data_format)

         padded_images = [
-            self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
+            self._pad_image(
+                image,
+                pad_size,
+                constant_values=constant_values,
+                data_format=data_format,
+                input_data_format=input_data_format,
+            )
             for image in images
         ]
         data = {"pixel_values": padded_images}

         if return_pixel_mask:
-            masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
+            masks = [
+                make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
+                for image in images
+            ]
             data["pixel_mask"] = masks

         return BatchFeature(data=data, tensor_type=return_tensors)
@@ -330,6 +379,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
         do_center_crop: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> PIL.Image.Image:
         """
@@ -374,8 +424,15 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                 The channel dimension format for the output image. Can be one of:
-                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size_divisor = size_divisor if size_divisor is not None else self.size_divisor
@@ -414,22 +471,41 @@ class BridgeTowerImageProcessor(BaseImageProcessor):

         if do_resize:
             images = [
-                self.resize(image=image, size=size, size_divisor=size_divisor, resample=resample) for image in images
+                self.resize(
+                    image=image,
+                    size=size,
+                    size_divisor=size_divisor,
+                    resample=resample,
+                    input_data_format=input_data_format,
+                )
+                for image in images
             ]

         if do_center_crop:
-            images = [self.center_crop(image=image, size=size) for image in images]
+            images = [
+                self.center_crop(image=image, size=size, input_data_format=input_data_format) for image in images
+            ]

         if do_rescale:
-            images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_normalize:
-            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]

-        images = [to_channel_dimension_format(image, data_format) for image in images]
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]

         if do_pad:
-            encoded_outputs = self.pad(images, return_pixel_mask=True, return_tensors=return_tensors)
+            encoded_outputs = self.pad(
+                images, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=data_format
+            )
         else:
             encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
...
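For the BridgeTower padding path, a compact numpy sketch of what `_pad_image` plus `make_pixel_mask` produce for a batch, assuming channels-first inputs (the helper name here is illustrative):

```python
import numpy as np

def pad_with_mask(image: np.ndarray, out_h: int, out_w: int):
    # image is (C, H, W); pad bottom/right with zeros and record valid pixels.
    c, h, w = image.shape
    padded = np.zeros((c, out_h, out_w), dtype=image.dtype)
    padded[:, :h, :w] = image
    mask = np.zeros((out_h, out_w), dtype=np.int64)
    mask[:h, :w] = 1  # 1 = valid pixel, 0 = padding
    return padded, mask

images = [np.ones((3, 20, 30)), np.ones((3, 25, 15))]
max_h = max(img.shape[1] for img in images)  # 25
max_w = max(img.shape[2] for img in images)  # 30
padded, masks = zip(*(pad_with_mask(img, max_h, max_w) for img in images))
print(padded[0].shape, int(masks[0].sum()))  # (3, 25, 30) 600
```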
src/transformers/models/chinese_clip/image_processing_chinese_clip.py
@@ -31,6 +31,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -124,6 +125,7 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
         size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -139,12 +141,22 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
                 Resampling filter to use when resizing the image.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred from the input
+                image.
         """
         size = get_size_dict(size, default_to_square=False)
         output_size = get_resize_output_image_size(
-            image, size=(size["height"], size["width"]), default_to_square=False
+            image, size=(size["height"], size["width"]), default_to_square=False, input_data_format=input_data_format
         )
-        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )

     def preprocess(
         self,
@@ -162,6 +174,7 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
         do_convert_rgb: bool = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> PIL.Image.Image:
         """
@@ -204,9 +217,15 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                 The channel dimension format for the output image. Can be one of:
-                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-                - Unset: defaults to the channel dimension format of the input image.
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
@@ -249,19 +268,36 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
         # All transformations expect numpy arrays.
         images = [to_numpy_array(image) for image in images]

+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
         if do_resize:
-            images = [self.resize(image=image, size=size, resample=resample) for image in images]
+            images = [
+                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_center_crop:
-            images = [self.center_crop(image=image, size=crop_size) for image in images]
+            images = [
+                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
+            ]

         if do_rescale:
-            images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_normalize:
-            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]

-        images = [to_channel_dimension_format(image, data_format) for image in images]
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]

         data = {"pixel_values": images}
         return BatchFeature(data=data, tensor_type=return_tensors)
src/transformers/models/clip/image_processing_clip.py
@@ -31,6 +31,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -124,6 +125,7 @@ class CLIPImageProcessor(BaseImageProcessor):
         size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -139,12 +141,23 @@ class CLIPImageProcessor(BaseImageProcessor):
                 Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size: if "shortest_edge" not in size:
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False) output_size = get_resize_output_image_size(
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
    def preprocess(
        self,

@@ -162,6 +175,7 @@ class CLIPImageProcessor(BaseImageProcessor):
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """

@@ -204,9 +218,15 @@ class CLIPImageProcessor(BaseImageProcessor):
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size

@@ -249,19 +269,36 @@ class CLIPImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [
                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
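At the `preprocess` level, the net effect is that callers can state the layout of ambiguous inputs instead of relying on inference. A minimal usage sketch (assuming a transformers release that includes this PR; the array here is illustrative):

import numpy as np
from transformers import CLIPImageProcessor

image_processor = CLIPImageProcessor()
image = np.random.randint(0, 256, (3, 224, 224), dtype=np.uint8)  # channels-first input
inputs = image_processor(images=image, input_data_format="channels_first", return_tensors="np")
print(inputs["pixel_values"].shape)  # (1, 3, 224, 224) with the default channels-first output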
@@ -31,6 +31,7 @@ from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,

@@ -118,6 +119,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
        crop_pct: float,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """

@@ -137,6 +139,9 @@ class ConvNextImageProcessor(BaseImageProcessor):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred from the input
                image.
        """
        size = get_size_dict(size, default_to_square=False)
        if "shortest_edge" not in size:

@@ -146,14 +151,34 @@ class ConvNextImageProcessor(BaseImageProcessor):
        if shortest_edge < 384:
            # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
            resize_shortest_edge = int(shortest_edge / crop_pct)
            resize_size = get_resize_output_image_size(
                image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format
            )
            image = resize(
                image=image,
                size=resize_size,
                resample=resample,
                data_format=data_format,
                input_data_format=input_data_format,
                **kwargs,
            )
            # then crop to (shortest_edge, shortest_edge)
            return center_crop(
                image=image,
                size=(shortest_edge, shortest_edge),
                data_format=data_format,
                input_data_format=input_data_format,
                **kwargs,
            )
        else:
            # warping (no cropping) when evaluated at 384 or larger
            return resize(
                image,
                size=(shortest_edge, shortest_edge),
                resample=resample,
                data_format=data_format,
                input_data_format=input_data_format,
                **kwargs,
            )
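For concreteness, the `shortest_edge < 384` branch above first resizes the short side to `shortest_edge / crop_pct` and only then center-crops. With the usual ConvNeXT defaults the arithmetic works out as follows (plain Python, not library code):

shortest_edge = 224
crop_pct = 224 / 256  # 0.875, the common ConvNeXT default
resize_shortest_edge = int(shortest_edge / crop_pct)
print(resize_shortest_edge)  # 256: resize short side to 256, then center-crop to (224, 224)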
    def preprocess(

@@ -170,6 +195,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """

@@ -209,8 +235,15 @@ class ConvNextImageProcessor(BaseImageProcessor):
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        crop_pct = crop_pct if crop_pct is not None else self.crop_pct

@@ -247,16 +280,33 @@ class ConvNextImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(
                    image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
                )
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
@@ -26,6 +26,7 @@ from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,

@@ -114,6 +115,7 @@ class DeiTImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """

@@ -131,6 +133,13 @@ class DeiTImageProcessor(BaseImageProcessor):
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.

@@ -139,7 +148,14 @@ class DeiTImageProcessor(BaseImageProcessor):
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
    def preprocess(
        self,

@@ -156,6 +172,7 @@ class DeiTImageProcessor(BaseImageProcessor):
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """

@@ -197,6 +214,12 @@ class DeiTImageProcessor(BaseImageProcessor):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample

@@ -235,19 +258,36 @@ class DeiTImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [
                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
@@ -115,7 +115,10 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]:
# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
def get_resize_output_image_size(
    input_image: np.ndarray,
    size: Union[int, Tuple[int, int], List[int]],
    max_size: Optional[int] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
    """
    Computes the output image size given the input image size and the desired output size. If the desired output size

@@ -129,8 +132,10 @@ def get_resize_output_image_size(
            The desired output size.
        max_size (`int`, *optional*):
            The maximum allowed output size.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
    """
    image_size = get_image_size(input_image, input_data_format)
    if isinstance(size, (list, tuple)):
        return size
@@ -200,23 +205,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(
    images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(images[0])

    if input_data_format == ChannelDimension.FIRST:
        _, max_height, max_width = max_across_indices([img.shape for img in images])
    elif input_data_format == ChannelDimension.LAST:
        max_height, max_width, _ = max_across_indices([img.shape for img in images])
    else:
        raise ValueError(f"Invalid channel dimension format: {input_data_format}")
    return (max_height, max_width)
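A standalone sketch of what `get_max_height_width` computes for a channels-first batch (plain NumPy, mirroring `max_across_indices` over the shape tuples):

import numpy as np

images = [np.zeros((3, 480, 640)), np.zeros((3, 512, 600))]
max_height = max(image.shape[1] for image in images)  # 512
max_width = max(image.shape[2] for image in images)  # 640
print((max_height, max_width))  # (512, 640): the common size every image in the batch is padded to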
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(
    image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

@@ -226,7 +236,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
        output_size (`Tuple[int, int]`):
            Output size of the mask.
    """
    input_height, input_width = get_image_size(image, channel_dim=input_data_format)
    mask = np.zeros(output_size, dtype=np.int64)
    mask[:input_height, :input_width] = 1
    return mask
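The mask simply marks the valid (un-padded) region. For example, a 2x3 image padded to a 4x4 canvas yields the following (a standalone re-creation of the same logic):

import numpy as np

mask = np.zeros((4, 4), dtype=np.int64)
mask[:2, :3] = 1  # rows 0-1 and columns 0-2 are real pixels; the rest is padding
print(mask)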
@@ -268,11 +278,16 @@ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
# Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation with DETR->DETA
def prepare_coco_detection_annotation(
    image,
    target,
    return_segmentation_masks: bool = False,
    input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
    """
    Convert the target in COCO format into the format expected by DETA.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)

    image_id = target["image_id"]
    image_id = np.asarray([image_id], dtype=np.int64)

@@ -357,12 +372,16 @@ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
# Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->DETA
def prepare_coco_panoptic_annotation(
    image: np.ndarray,
    target: Dict,
    masks_path: Union[str, pathlib.Path],
    return_masks: bool = True,
    input_data_format: Union[ChannelDimension, str] = None,
) -> Dict:
    """
    Prepare a coco panoptic annotation for DETA.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)
    annotation_path = pathlib.Path(masks_path) / target["file_name"]

    new_target = {}

@@ -522,6 +541,7 @@ class DetaImageProcessor(BaseImageProcessor):
        format: Optional[AnnotionFormat] = None,
        return_segmentation_masks: bool = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> Dict:
        """
        Prepare an annotation for feeding into DETA model.

@@ -530,11 +550,17 @@ class DetaImageProcessor(BaseImageProcessor):
        if format == AnnotionFormat.COCO_DETECTION:
            return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_detection_annotation(
                image, target, return_segmentation_masks, input_data_format=input_data_format
            )
        elif format == AnnotionFormat.COCO_PANOPTIC:
            return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_panoptic_annotation(
                image,
                target,
                masks_path=masks_path,
                return_masks=return_segmentation_masks,
                input_data_format=input_data_format,
            )
        else:
            raise ValueError(f"Format {format} is not supported.")
@@ -571,15 +597,32 @@ class DetaImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
        int, smaller edge of the image will be matched to this number.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                The desired output size. Can contain keys `shortest_edge` and `longest_edge` or `height` and `width`.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use if resizing the image.
            data_format (`ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred from the input
                image.
        """
        size = get_size_dict(size, default_to_square=False)
        if "shortest_edge" in size and "longest_edge" in size:
            size = get_resize_output_image_size(
                image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
            )
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:

@@ -587,7 +630,9 @@ class DetaImageProcessor(BaseImageProcessor):
            raise ValueError(
                "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
                f" {size.keys()}."
            )
        image = resize(
            image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format
        )
        return image
    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation

@@ -606,7 +651,11 @@ class DetaImageProcessor(BaseImageProcessor):
    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
    def rescale(
        self,
        image: np.ndarray,
        rescale_factor: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Rescale the image by the given factor. image = image * rescale_factor.

@@ -621,8 +670,13 @@ class DetaImageProcessor(BaseImageProcessor):
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. If unset, is inferred from the input image. Can be
                one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
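`rescale` is a plain multiplication; the usual factor of 1/255 maps 8-bit pixels into [0, 1]. A hedged sketch against the shared `transformers.image_transforms.rescale` helper (assuming a release that includes this change):

import numpy as np
from transformers.image_transforms import rescale

image = np.array([[[0, 128, 255]]], dtype=np.uint8)
print(rescale(image, scale=1 / 255))  # float values of roughly [0.0, 0.502, 1.0]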
    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
    def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:

@@ -639,18 +693,24 @@ class DetaImageProcessor(BaseImageProcessor):
        output_size: Tuple[int, int],
        constant_values: Union[float, Iterable[float]] = 0,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Pad an image with zeros to the given size.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = output_size

        pad_bottom = output_height - input_height
        pad_right = output_width - input_width
        padding = ((0, pad_bottom), (0, pad_right))
        padded_image = pad(
            image,
            padding,
            mode=PaddingMode.CONSTANT,
            constant_values=constant_values,
            data_format=data_format,
            input_data_format=input_data_format,
        )
        return padded_image
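Because the padding is applied only at the bottom and right, coordinates measured from the top-left corner stay valid. The arithmetic from `_pad_image`, spelled out with illustrative numbers:

input_height, input_width = 480, 600  # the image
output_height, output_width = 512, 640  # the batch-wide max from get_max_height_width
padding = ((0, output_height - input_height), (0, output_width - input_width))
print(padding)  # ((0, 32), (0, 40)): nothing above/left, zeros below/right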
@@ -662,6 +722,7 @@ class DetaImageProcessor(BaseImageProcessor):
        return_pixel_mask: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
        """
        Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width

@@ -683,17 +744,28 @@ class DetaImageProcessor(BaseImageProcessor):
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        pad_size = get_max_height_width(images, input_data_format=input_data_format)

        padded_images = [
            self._pad_image(
                image,
                pad_size,
                constant_values=constant_values,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for image in images
        ]
        data = {"pixel_values": padded_images}

        if return_pixel_mask:
            masks = [
                make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
                for image in images
            ]
            data["pixel_mask"] = masks

        return BatchFeature(data=data, tensor_type=return_tensors)
@@ -716,6 +788,7 @@ class DetaImageProcessor(BaseImageProcessor):
        format: Optional[Union[str, AnnotionFormat]] = None,
        return_tensors: Optional[Union[TensorType, str]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """

@@ -761,8 +834,17 @@ class DetaImageProcessor(BaseImageProcessor):
                Format of the annotations.
            return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
                Type of tensors to return. If `None`, will return the list of images.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        if "pad_and_return_pixel_mask" in kwargs:
            logger.warning_once(

@@ -839,13 +921,22 @@ class DetaImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
        if annotations is not None:
            prepared_images = []
            prepared_annotations = []
            for image, target in zip(images, annotations):
                target = self.prepare_annotation(
                    image,
                    target,
                    format,
                    return_segmentation_masks=return_segmentation_masks,
                    masks_path=masks_path,
                    input_data_format=input_data_format,
                )
                prepared_images.append(image)
                prepared_annotations.append(target)

@@ -858,33 +949,47 @@ class DetaImageProcessor(BaseImageProcessor):
            if annotations is not None:
                resized_images, resized_annotations = [], []
                for image, target in zip(images, annotations):
                    orig_size = get_image_size(image, input_data_format)
                    resized_image = self.resize(
                        image, size=size, resample=resample, input_data_format=input_data_format
                    )
                    resized_annotation = self.resize_annotation(
                        target, orig_size, get_image_size(resized_image, input_data_format)
                    )
                    resized_images.append(resized_image)
                    resized_annotations.append(resized_annotation)
                images = resized_images
                annotations = resized_annotations
                del resized_images, resized_annotations
            else:
                images = [
                    self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
                    for image in images
                ]

        if do_rescale:
            images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

        if do_normalize:
            images = [
                self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
            ]
            if annotations is not None:
                annotations = [
                    self.normalize_annotation(annotation, get_image_size(image, input_data_format))
                    for annotation, image in zip(annotations, images)
                ]

        if do_pad:
            # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
            data = self.pad(
                images, return_pixel_mask=True, data_format=data_format, input_data_format=input_data_format
            )
        else:
            images = [
                to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
                for image in images
            ]
            data = {"pixel_values": images}

        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
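End to end, the detection `preprocess` now threads `input_data_format` through resize, rescale, normalize and pad. A minimal sketch (assuming a release that includes this change; the DETR processor below behaves analogously, and the printed shapes depend on the checkpoint defaults):

import numpy as np
from transformers import DetaImageProcessor

image_processor = DetaImageProcessor()
image = np.random.randint(0, 256, (3, 480, 640), dtype=np.uint8)  # channels-first input
inputs = image_processor(images=image, input_data_format="channels_first", return_tensors="np")
print(inputs["pixel_values"].shape)  # e.g. (1, 3, 800, 1066) after shortest/longest-edge resizing
print(inputs["pixel_mask"].shape)  # e.g. (1, 800, 1066); all ones for a single un-padded image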
@@ -121,7 +121,10 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]:
def get_resize_output_image_size(
    input_image: np.ndarray,
    size: Union[int, Tuple[int, int], List[int]],
    max_size: Optional[int] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
    """
    Computes the output image size given the input image size and the desired output size. If the desired output size

@@ -135,8 +138,10 @@ def get_resize_output_image_size(
            The desired output size.
        max_size (`int`, *optional*):
            The maximum allowed output size.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
    """
    image_size = get_image_size(input_image, input_data_format)
    if isinstance(size, (list, tuple)):
        return size

@@ -203,23 +208,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
def get_max_height_width(
    images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(images[0])

    if input_data_format == ChannelDimension.FIRST:
        _, max_height, max_width = max_across_indices([img.shape for img in images])
    elif input_data_format == ChannelDimension.LAST:
        max_height, max_width, _ = max_across_indices([img.shape for img in images])
    else:
        raise ValueError(f"Invalid channel dimension format: {input_data_format}")
    return (max_height, max_width)

# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
def make_pixel_mask(
    image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

@@ -229,7 +239,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
        output_size (`Tuple[int, int]`):
            Output size of the mask.
    """
    input_height, input_width = get_image_size(image, channel_dim=input_data_format)
    mask = np.zeros(output_size, dtype=np.int64)
    mask[:input_height, :input_width] = 1
    return mask
@@ -271,11 +281,16 @@ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
def prepare_coco_detection_annotation(
    image,
    target,
    return_segmentation_masks: bool = False,
    input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
    """
    Convert the target in COCO format into the format expected by DETR.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)

    image_id = target["image_id"]
    image_id = np.asarray([image_id], dtype=np.int64)

@@ -358,12 +373,16 @@ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
def prepare_coco_panoptic_annotation(
    image: np.ndarray,
    target: Dict,
    masks_path: Union[str, pathlib.Path],
    return_masks: bool = True,
    input_data_format: Union[ChannelDimension, str] = None,
) -> Dict:
    """
    Prepare a coco panoptic annotation for DETR.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)
    annotation_path = pathlib.Path(masks_path) / target["file_name"]

    new_target = {}

@@ -822,6 +841,7 @@ class DetrImageProcessor(BaseImageProcessor):
        format: Optional[AnnotionFormat] = None,
        return_segmentation_masks: bool = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> Dict:
        """
        Prepare an annotation for feeding into DETR model.

@@ -830,11 +850,17 @@ class DetrImageProcessor(BaseImageProcessor):
        if format == AnnotionFormat.COCO_DETECTION:
            return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_detection_annotation(
                image, target, return_segmentation_masks, input_data_format=input_data_format
            )
        elif format == AnnotionFormat.COCO_PANOPTIC:
            return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
            target = prepare_coco_panoptic_annotation(
                image,
                target,
                masks_path=masks_path,
                return_masks=return_segmentation_masks,
                input_data_format=input_data_format,
            )
        else:
            raise ValueError(f"Format {format} is not supported.")
@@ -867,11 +893,26 @@ class DetrImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
        int, smaller edge of the image will be matched to this number.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary containing the size to resize to. Can contain the keys `shortest_edge` and `longest_edge` or
                `height` and `width`.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter to use if resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        if "max_size" in kwargs:
            logger.warning_once(

@@ -883,7 +924,9 @@ class DetrImageProcessor(BaseImageProcessor):
            max_size = None
        size = get_size_dict(size, max_size=max_size, default_to_square=False)
        if "shortest_edge" in size and "longest_edge" in size:
            size = get_resize_output_image_size(
                image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
            )
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:

@@ -891,7 +934,9 @@ class DetrImageProcessor(BaseImageProcessor):
            raise ValueError(
                "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
                f" {size.keys()}."
            )
        image = resize(
            image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
        return image
    def resize_annotation(

@@ -909,7 +954,11 @@ class DetrImageProcessor(BaseImageProcessor):
    # TODO (Amy) - update to use `rescale_factor` instead of `scale`
    def rescale(
        self,
        image: np.ndarray,
        rescale_factor: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Rescale the image by the given factor. image = image * rescale_factor.

@@ -924,8 +973,13 @@ class DetrImageProcessor(BaseImageProcessor):
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. If unset, is inferred from the input image. Can be
                one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
    def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
        """

@@ -940,18 +994,24 @@ class DetrImageProcessor(BaseImageProcessor):
        output_size: Tuple[int, int],
        constant_values: Union[float, Iterable[float]] = 0,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Pad an image with zeros to the given size.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = output_size

        pad_bottom = output_height - input_height
        pad_right = output_width - input_width
        padding = ((0, pad_bottom), (0, pad_right))
        padded_image = pad(
            image,
            padding,
            mode=PaddingMode.CONSTANT,
            constant_values=constant_values,
            data_format=data_format,
            input_data_format=input_data_format,
        )
        return padded_image
@@ -962,6 +1022,7 @@ class DetrImageProcessor(BaseImageProcessor):
        return_pixel_mask: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
        """
        Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width

@@ -983,17 +1044,28 @@ class DetrImageProcessor(BaseImageProcessor):
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        pad_size = get_max_height_width(images, input_data_format=input_data_format)

        padded_images = [
            self._pad_image(
                image,
                pad_size,
                constant_values=constant_values,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for image in images
        ]
        data = {"pixel_values": padded_images}

        if return_pixel_mask:
            masks = [
                make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
                for image in images
            ]
            data["pixel_mask"] = masks

        return BatchFeature(data=data, tensor_type=return_tensors)
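A hedged usage sketch of the batched `pad` with an explicit input layout (the two arrays are hypothetical):

import numpy as np
from transformers import DetrImageProcessor

processor = DetrImageProcessor()
images = [
    np.zeros((400, 600, 3), dtype=np.float32),  # channels-last, different heights/widths
    np.zeros((480, 640, 3), dtype=np.float32),
]
batch = processor.pad(images, return_pixel_mask=True, input_data_format="channels_last")
# All pixel_values now share the batch max size; pixel_mask flags the real (unpadded) pixels.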
@@ -1016,6 +1088,7 @@ class DetrImageProcessor(BaseImageProcessor):
        format: Optional[Union[str, AnnotionFormat]] = None,
        return_tensors: Optional[Union[TensorType, str]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """

@@ -1061,8 +1134,17 @@ class DetrImageProcessor(BaseImageProcessor):
                Format of the annotations.
            return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
                Type of tensors to return. If `None`, will return the list of images.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        if "pad_and_return_pixel_mask" in kwargs:
            logger.warning_once(

@@ -1147,13 +1229,22 @@ class DetrImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
        if annotations is not None:
            prepared_images = []
            prepared_annotations = []
            for image, target in zip(images, annotations):
                target = self.prepare_annotation(
                    image,
                    target,
                    format,
                    return_segmentation_masks=return_segmentation_masks,
                    masks_path=masks_path,
                    input_data_format=input_data_format,
                )
                prepared_images.append(image)
                prepared_annotations.append(target)

@@ -1166,33 +1257,49 @@ class DetrImageProcessor(BaseImageProcessor):
            if annotations is not None:
                resized_images, resized_annotations = [], []
                for image, target in zip(images, annotations):
                    orig_size = get_image_size(image, input_data_format)
                    resized_image = self.resize(
                        image, size=size, max_size=max_size, resample=resample, input_data_format=input_data_format
                    )
                    resized_annotation = self.resize_annotation(
                        target, orig_size, get_image_size(resized_image, input_data_format)
                    )
                    resized_images.append(resized_image)
                    resized_annotations.append(resized_annotation)
                images = resized_images
                annotations = resized_annotations
                del resized_images, resized_annotations
            else:
                images = [
                    self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
                    for image in images
                ]

        if do_rescale:
            images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

        if do_normalize:
            images = [
                self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
            ]
            if annotations is not None:
                annotations = [
                    self.normalize_annotation(
                        annotation, get_image_size(image, input_data_format), input_data_format=input_data_format
                    )
                    for annotation, image in zip(annotations, images)
                ]

        if do_pad:
            # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
            data = self.pad(
                images, return_pixel_mask=True, data_format=data_format, input_data_format=input_data_format
            )
        else:
            images = [
                to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
                for image in images
            ]
            data = {"pixel_values": images}

        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
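Putting the whole DETR flow together, a minimal end-to-end sketch (the random channels-first image is hypothetical):

import numpy as np
from transformers import DetrImageProcessor

processor = DetrImageProcessor()
image = np.random.randint(0, 256, size=(3, 480, 640), dtype=np.uint8)
inputs = processor.preprocess(
    image,
    return_tensors="np",
    input_data_format="channels_first",  # disables per-batch layout inference
)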
@@ -32,6 +32,7 @@ from ...image_utils import (
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,

@@ -122,7 +123,11 @@ class DonutImageProcessor(BaseImageProcessor):
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def align_long_axis(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Align the long axis of the image to the longest axis of the specified size.

@@ -132,11 +137,15 @@ class DonutImageProcessor(BaseImageProcessor):
                The image to be aligned.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to align the long axis to.
            data_format (`str` or `ChannelDimension`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.

        Returns:
            `np.ndarray`: The aligned image.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = size["height"], size["width"]

        if (output_width < output_height and input_width > input_height) or (

@@ -145,7 +154,7 @@ class DonutImageProcessor(BaseImageProcessor):
            image = np.rot90(image, 3)

        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

        return image
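The rotation rule above can be checked with plain numpy; a sketch with hypothetical shapes:

import numpy as np

image = np.zeros((480, 640, 3))        # landscape input: height < width
size = {"height": 1280, "width": 960}  # portrait target: height > width
# The orientations disagree, so align_long_axis applies np.rot90(image, 3).
rotated = np.rot90(image, 3)
assert rotated.shape == (640, 480, 3)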
@@ -155,6 +164,7 @@ class DonutImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        random_padding: bool = False,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Pad the image to the specified size.

@@ -168,9 +178,11 @@ class DonutImageProcessor(BaseImageProcessor):
                Whether to use random padding or not.
            data_format (`str` or `ChannelDimension`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_width = output_width - input_width
        delta_height = output_height - input_height

@@ -186,7 +198,7 @@ class DonutImageProcessor(BaseImageProcessor):
        pad_right = delta_width - pad_left
        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(image, padding, data_format=data_format, input_data_format=input_data_format)
    def pad(self, *args, **kwargs):
        logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.")

@@ -198,6 +210,7 @@ class DonutImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """

@@ -213,8 +226,10 @@ class DonutImageProcessor(BaseImageProcessor):
                The resampling filter to use.
            data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = size["height"], size["width"]

        # We always resize to the smallest of either the input or output size.

@@ -230,7 +245,13 @@ class DonutImageProcessor(BaseImageProcessor):
            height = int(input_height * width / input_width)

        return resize(
            image,
            size=(height, width),
            resample=resample,
            reducing_gap=2.0,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
    def resize(

@@ -239,6 +260,7 @@ class DonutImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """

@@ -254,11 +276,22 @@ class DonutImageProcessor(BaseImageProcessor):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        size = get_size_dict(size)
        shortest_edge = min(size["height"], size["width"])
        output_size = get_resize_output_image_size(
            image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
        )
        resized_image = resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
        return resized_image
    def preprocess(

@@ -278,6 +311,7 @@ class DonutImageProcessor(BaseImageProcessor):
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """

@@ -327,6 +361,12 @@ class DonutImageProcessor(BaseImageProcessor):
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: defaults to the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size

@@ -367,25 +407,45 @@ class DonutImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_align_long_axis:
            images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images]

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_thumbnail:
            images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images]

        if do_pad:
            images = [
                self.pad_image(
                    image=image, size=size, random_padding=random_padding, input_data_format=input_data_format
                )
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
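An end-to-end sketch of the Donut pipeline above, letting the layout be inferred from a PIL input (the blank page is hypothetical):

from PIL import Image
from transformers import DonutImageProcessor

processor = DonutImageProcessor()
image = Image.new("RGB", (640, 480))  # hypothetical document page
# input_data_format is left unset, so it is inferred from the first image.
outputs = processor.preprocess(image, return_tensors="np")
pixel_values = outputs["pixel_values"]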
@@ -28,6 +28,7 @@ from ...image_utils import (
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_torch_available,
    is_torch_tensor,
    make_list_of_images,

@@ -48,7 +49,11 @@ logger = logging.get_logger(__name__)
def get_resize_output_image_size(
    input_image: np.ndarray,
    output_size: Union[int, Iterable[int]],
    keep_aspect_ratio: bool,
    multiple: int,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
    def constraint_to_multiple_of(val, multiple, min_val=0, max_val=None):
        x = round(val / multiple) * multiple

@@ -63,7 +68,7 @@ def get_resize_output_image_size(
    output_size = (output_size, output_size) if isinstance(output_size, int) else output_size

    input_height, input_width = get_image_size(input_image, input_data_format)
    output_height, output_width = output_size

    # determine new height and width
@@ -97,7 +102,7 @@ class DPTImageProcessor(BaseImageProcessor):
        keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
            If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
            be overridden by `keep_aspect_ratio` in `preprocess`.
        ensure_multiple_of (`int`, *optional*, defaults to 1):
            If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be
            overridden by `ensure_multiple_of` in `preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):

@@ -156,6 +161,7 @@ class DPTImageProcessor(BaseImageProcessor):
        ensure_multiple_of: int = 1,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """

@@ -170,7 +176,7 @@ class DPTImageProcessor(BaseImageProcessor):
                Target size of the output image.
            keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
            ensure_multiple_of (`int`, *optional*, defaults to 1):
                The image is resized to a size that is a multiple of this value.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size

@@ -179,6 +185,8 @@ class DPTImageProcessor(BaseImageProcessor):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:

@@ -188,8 +196,16 @@ class DPTImageProcessor(BaseImageProcessor):
            output_size=(size["height"], size["width"]),
            keep_aspect_ratio=keep_aspect_ratio,
            multiple=ensure_multiple_of,
            input_data_format=input_data_format,
        )
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
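To make `ensure_multiple_of` concrete, a simplified sketch of the rounding rule used by `constraint_to_multiple_of` (the real helper also clamps to `min_val`/`max_val`; the numbers are hypothetical):

def constraint_to_multiple_of(val, multiple):
    # Round to the nearest multiple, as in the helper above (clamping omitted).
    return round(val / multiple) * multiple

assert constraint_to_multiple_of(518, 32) == 512
assert constraint_to_multiple_of(700, 32) == 704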
    def preprocess(
        self,

@@ -206,6 +222,7 @@ class DPTImageProcessor(BaseImageProcessor):
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """

@@ -249,6 +266,12 @@ class DPTImageProcessor(BaseImageProcessor):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size

@@ -282,16 +305,31 @@ class DPTImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
@@ -30,6 +30,7 @@ from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_batched,
    to_numpy_array,
    valid_images,

@@ -116,6 +117,7 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """

@@ -133,6 +135,8 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.

        Returns:
            `np.ndarray`: The resized image.

@@ -140,13 +144,17 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
        size = get_size_dict(size)

        if "shortest_edge" in size:
            size = get_resize_output_image_size(
                image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
            )
            # size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError(f"Size must contain 'height' and 'width' keys or 'shortest_edge' key. Got {size.keys()}")

        return resize(
            image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
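For intuition on the `shortest_edge` branch, a sketch of the proportional scaling (the exact rounding inside `get_resize_output_image_size` is an assumption here; sizes are hypothetical):

height, width, shortest_edge = 480, 640, 224
scale = shortest_edge / min(height, width)
new_size = (int(height * scale), int(width * scale))
assert new_size == (224, 298)  # short side pinned to 224, long side scaled to match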
    def preprocess(
        self,

@@ -163,6 +171,7 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """

@@ -205,6 +214,12 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale

@@ -241,19 +256,36 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [
                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
@@ -26,6 +26,7 @@ from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,

@@ -123,6 +124,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.NEAREST,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """

@@ -140,6 +142,13 @@ class EfficientNetImageProcessor(BaseImageProcessor):
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.

@@ -148,7 +157,14 @@ class EfficientNetImageProcessor(BaseImageProcessor):
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
    def rescale(
        self,

@@ -156,6 +172,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
        scale: Union[int, float],
        offset: bool = True,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """

@@ -177,8 +194,12 @@ class EfficientNetImageProcessor(BaseImageProcessor):
                Whether to scale the image in both negative and positive directions.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        rescaled_image = rescale(
            image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs
        )

        if offset:
            rescaled_image = rescaled_image - 1
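The `offset` branch is a simple affine shift: with `scale=1/127.5` a uint8 range maps onto [-1, 1]. A worked check (hypothetical pixel values):

import numpy as np

pixels = np.array([0.0, 127.5, 255.0])
rescaled = pixels * (1 / 127.5) - 1  # scale, then subtract the offset as above
assert np.allclose(rescaled, [-1.0, 0.0, 1.0])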
@@ -202,6 +223,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
        include_top: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """

@@ -247,6 +269,12 @@ class EfficientNetImageProcessor(BaseImageProcessor):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample

@@ -287,22 +315,44 @@ class EfficientNetImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [
                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(
                    image=image, scale=rescale_factor, offset=rescale_offset, input_data_format=input_data_format
                )
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        if include_top:
            images = [
                self.normalize(image=image, mean=0, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
@@ -29,6 +29,7 @@ from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,

@@ -338,6 +339,7 @@ class FlavaImageProcessor(BaseImageProcessor):
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """

@@ -355,6 +357,13 @@ class FlavaImageProcessor(BaseImageProcessor):
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.

@@ -363,7 +372,14 @@ class FlavaImageProcessor(BaseImageProcessor):
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
    def map_pixels(self, image: np.ndarray) -> np.ndarray:
        return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
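`map_pixels` squeezes pixels from [0, 1] into [eps, 1 - eps], presumably to keep codebook inputs away from the logit boundaries; a quick check with a hypothetical epsilon:

LOGIT_LAPLACE_EPS = 0.1  # hypothetical value, for illustration only
for pixel in (0.0, 0.5, 1.0):
    mapped = (1 - 2 * LOGIT_LAPLACE_EPS) * pixel + LOGIT_LAPLACE_EPS
    assert LOGIT_LAPLACE_EPS <= mapped <= 1 - LOGIT_LAPLACE_EPS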
@@ -383,6 +399,7 @@ class FlavaImageProcessor(BaseImageProcessor):
        image_std: Optional[Union[float, List[float]]] = None,
        do_map_pixels: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[ChannelDimension] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        if do_resize and size is None or resample is None:

@@ -397,23 +414,27 @@ class FlavaImageProcessor(BaseImageProcessor):
        # All transformations expect numpy arrays.
        image = to_numpy_array(image)

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(image)

        if do_resize:
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

        if do_center_crop:
            image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        if do_map_pixels:
            image = self.map_pixels(image)

        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        return image
    def preprocess(

@@ -452,6 +473,7 @@ class FlavaImageProcessor(BaseImageProcessor):
        codebook_image_std: Optional[Iterable[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """

@@ -533,6 +555,12 @@ class FlavaImageProcessor(BaseImageProcessor):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size

@@ -615,6 +643,7 @@ class FlavaImageProcessor(BaseImageProcessor):
                image_std=image_std,
                do_map_pixels=False,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]

@@ -636,6 +665,7 @@ class FlavaImageProcessor(BaseImageProcessor):
                image_std=codebook_image_std,
                do_map_pixels=codebook_do_map_pixels,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]
...@@ -25,6 +25,7 @@ from ...image_utils import ( ...@@ -25,6 +25,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
PILImageResampling, PILImageResampling,
get_image_size, get_image_size,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -75,6 +76,7 @@ class GLPNImageProcessor(BaseImageProcessor): ...@@ -75,6 +76,7 @@ class GLPNImageProcessor(BaseImageProcessor):
size_divisor: int, size_divisor: int,
resample: PILImageResampling = PILImageResampling.BILINEAR, resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -95,15 +97,27 @@ class GLPNImageProcessor(BaseImageProcessor): ...@@ -95,15 +97,27 @@ class GLPNImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not set, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The resized image.
"""
height, width = get_image_size(image, channel_dim=input_data_format)
# Rounds the height and width down to the closest multiple of size_divisor
new_h = height // size_divisor * size_divisor
new_w = width // size_divisor * size_divisor
image = resize(
image,
(new_h, new_w),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
return image
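The resize above floors both dimensions to multiples of `size_divisor` rather than rounding to the nearest one, so the output can only shrink. A minimal self-contained sketch of that rule, assuming the documented default `size_divisor` of 32:

# Floor each dimension to the closest multiple of size_divisor (32 by default).
def round_down(dim: int, size_divisor: int = 32) -> int:
    return dim // size_divisor * size_divisor

assert round_down(480) == 480  # already a multiple of 32, unchanged
assert round_down(500) == 480  # floored down, not rounded up to 512
assert round_down(700) == 672  # 21 * 32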
def preprocess(
...@@ -115,6 +129,7 @@ class GLPNImageProcessor(BaseImageProcessor):
do_rescale: Optional[bool] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
"""
...@@ -144,6 +159,12 @@ class GLPNImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_rescale = do_rescale if do_rescale is not None else self.do_rescale
...@@ -161,13 +182,22 @@ class GLPNImageProcessor(BaseImageProcessor): ...@@ -161,13 +182,22 @@ class GLPNImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(img) for img in images] images = [to_numpy_array(img) for img in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [
self.resize(image, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_rescale:
images = [self.rescale(image, scale=1 / 255, input_data_format=input_data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images} data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
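When `input_data_format` is left unset, the format is inferred from `images[0]` and assumed to hold for the whole batch; the explicit argument exists because inference is a guess (a 3x3x3 array, for instance, is ambiguous between channels-first and channels-last). A hedged usage sketch, with the checkpoint name as an illustrative assumption:

# Sketch only: assumes the "vinvino02/glpn-kitti" checkpoint and a transformers
# build that includes the new keyword.
import numpy as np
from transformers import GLPNImageProcessor

processor = GLPNImageProcessor.from_pretrained("vinvino02/glpn-kitti")

image = np.random.rand(3, 480, 640).astype(np.float32)  # channels-first
# Passing the hint avoids inference; 480 and 640 are multiples of 32, so the
# size_divisor rounding leaves the spatial dimensions untouched.
inputs = processor(image, input_data_format="channels_first", return_tensors="np")
print(inputs["pixel_values"].shape)  # expected: (1, 3, 480, 640)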
...@@ -24,6 +24,7 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
...@@ -107,6 +108,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
...@@ -124,6 +126,13 @@ class ImageGPTImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
...@@ -132,12 +141,20 @@ class ImageGPTImageProcessor(BaseImageProcessor):
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def normalize(
self,
image: np.ndarray,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Normalizes an image's pixel values to be between [-1, 1].
...@@ -147,8 +164,10 @@ class ImageGPTImageProcessor(BaseImageProcessor):
Image to normalize.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
image = rescale(image=image, scale=1 / 127.5, data_format=data_format) image = rescale(image=image, scale=1 / 127.5, data_format=data_format, input_data_format=input_data_format)
image = image - 1 image = image - 1
return image return image
...@@ -163,6 +182,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
clusters: Optional[Union[List[List[int]], np.ndarray]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
...@@ -197,6 +217,12 @@ class ImageGPTImageProcessor(BaseImageProcessor):
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Only has an effect if `do_color_quantize` is set to `False`.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
...@@ -224,14 +250,21 @@ class ImageGPTImageProcessor(BaseImageProcessor): ...@@ -224,14 +250,21 @@ class ImageGPTImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_normalize:
images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]
if do_color_quantize:
images = [to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) for image in images]
# color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
images = np.array(images)
images = color_quantize(images, clusters).reshape(images.shape[:-1])
...@@ -243,7 +276,10 @@ class ImageGPTImageProcessor(BaseImageProcessor):
# We need to convert back to a list of images to keep consistent behaviour across processors.
images = list(images)
else:
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]
data = {"input_ids": images} data = {"input_ids": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
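With `do_color_quantize` enabled (the default), the batch comes back as cluster ids under `input_ids` rather than `pixel_values`. A hedged end-to-end sketch; the checkpoint name and the 32x32 token grid are assumptions about the small ImageGPT model:

# Sketch only: assumes the "openai/imagegpt-small" checkpoint, whose stored
# clusters drive the color quantization, and a build with the new keyword.
import numpy as np
from transformers import ImageGPTImageProcessor

processor = ImageGPTImageProcessor.from_pretrained("openai/imagegpt-small")

image = (np.random.rand(3, 64, 64) * 255).astype(np.float32)  # channels-first
encoding = processor(image, input_data_format="channels_first", return_tensors="np")
print(encoding["input_ids"].shape)  # expected: (1, 1024) for a 32x32 grid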