Unverified Commit 6bca43bb authored by amyeroberts, committed by GitHub

Input data format (#25464)

* Add copied from statements for image processors

* Move out rescale and normalize to base image processor

* Remove rescale and normalize from vit (post rebase)

* Update docstrings and tidy up

* PR comments

* Add input_data_format as preprocess argument

* Resolve tests and tidy up

* Remove num_channels argument

* Update doc strings -> default ints not in code formatting
parent a6609caf
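
This PR threads a new `input_data_format` argument through the LayoutLMv2, LayoutLMv3, LeViT, Mask2Former and MaskFormer image processors (and the standalone helpers they call), so callers can state whether their arrays are channels-first or channels-last instead of relying on shape inference. A minimal sketch of the new call pattern follows; the processor choice, array size, and printed shape are illustrative, not taken from this diff:

import numpy as np
from transformers import LevitImageProcessor

processor = LevitImageProcessor()  # default config; any of the touched processors works the same way

# A channels-last image: (height, width, num_channels).
image = np.random.randint(0, 256, size=(30, 40, 3), dtype=np.uint8)

# Previously the layout was always inferred from the array shape; now it can
# be stated explicitly, which matters when the shape alone is ambiguous.
encoding = processor.preprocess(image, input_data_format="channels_last", return_tensors="np")
print(encoding["pixel_values"].shape)  # channels-first output, e.g. (1, 3, 224, 224)

If `input_data_format` is left unset, behaviour is unchanged: the format is inferred from the first image and assumed for the whole batch.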
@@ -24,6 +24,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -50,12 +51,17 @@ def normalize_box(box, width, height):
     ]


-def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str] = None):
+def apply_tesseract(
+    image: np.ndarray,
+    lang: Optional[str],
+    tesseract_config: Optional[str] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+):
     """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
     tesseract_config = tesseract_config if tesseract_config is not None else ""

     # apply OCR
-    pil_image = to_pil_image(image)
+    pil_image = to_pil_image(image, input_data_format=input_data_format)
     image_width, image_height = pil_image.size
     data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
     words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
@@ -138,6 +144,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
         size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BILINEAR,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -155,6 +162,13 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

         Returns:
             `np.ndarray`: The resized image.
@@ -163,7 +177,14 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
         if "height" not in size or "width" not in size:
             raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
         output_size = (size["height"], size["width"])
-        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )

     def preprocess(
         self,
@@ -176,6 +197,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
         tesseract_config: Optional[str] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> PIL.Image.Image:
         """
@@ -233,21 +255,30 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
         # All transformations expect numpy arrays.
         images = [to_numpy_array(image) for image in images]

+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
         if apply_ocr:
             requires_backends(self, "pytesseract")
             words_batch = []
             boxes_batch = []
             for image in images:
-                words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
+                words, boxes = apply_tesseract(image, ocr_lang, tesseract_config, input_data_format=input_data_format)
                 words_batch.append(words)
                 boxes_batch.append(boxes)

         if do_resize:
-            images = [self.resize(image=image, size=size, resample=resample) for image in images]
+            images = [
+                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                for image in images
+            ]

         # flip color channels from RGB to BGR (as Detectron2 requires this)
-        images = [flip_channel_order(image) for image in images]
-        images = [to_channel_dimension_format(image, data_format) for image in images]
+        images = [flip_channel_order(image, input_data_format=input_data_format) for image in images]
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]

         data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
...
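
LayoutLMv3 below receives the same treatment. The fallback both files share is worth spelling out: when nothing is passed, the format of the first image is inferred from its shape and assumed for every image in the batch. A rough sketch of what that inference amounts to; the real helper is `transformers.image_utils.infer_channel_dimension_format` and handles more edge cases than shown:

import numpy as np
from transformers.image_utils import infer_channel_dimension_format

chw = np.zeros((3, 224, 224))  # channels-first layout
hwc = np.zeros((224, 224, 3))  # channels-last layout

print(infer_channel_dimension_format(chw))  # ChannelDimension.FIRST
print(infer_channel_dimension_format(hwc))  # ChannelDimension.LAST

# An ambiguous shape such as (3, 3, 3) could be either layout; the helper has
# to pick one, which is exactly the case `input_data_format` lets callers avoid.
print(infer_channel_dimension_format(np.zeros((3, 3, 3))))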
@@ -26,6 +26,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -52,11 +53,16 @@ def normalize_box(box, width, height):
     ]


-def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str]):
+def apply_tesseract(
+    image: np.ndarray,
+    lang: Optional[str],
+    tesseract_config: Optional[str],
+    input_data_format: Optional[Union[ChannelDimension, str]] = None,
+):
     """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
     # apply OCR
-    pil_image = to_pil_image(image)
+    pil_image = to_pil_image(image, input_data_format=input_data_format)
     image_width, image_height = pil_image.size
     data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
     words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
@@ -164,6 +170,7 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
         size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BILINEAR,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -181,6 +188,13 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

         Returns:
             `np.ndarray`: The resized image.
@@ -189,7 +203,14 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
         if "height" not in size or "width" not in size:
             raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
         output_size = (size["height"], size["width"])
-        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )

     def preprocess(
         self,
@@ -207,6 +228,7 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
         tesseract_config: Optional[str] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> PIL.Image.Image:
         """
@@ -252,6 +274,12 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
                 The channel dimension format for the output image. Can be one of:
                 - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
@@ -286,26 +314,41 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
         # All transformations expect numpy arrays.
         images = [to_numpy_array(image) for image in images]

+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
         # Tesseract OCR to get words + normalized bounding boxes
         if apply_ocr:
             requires_backends(self, "pytesseract")
             words_batch = []
             boxes_batch = []
             for image in images:
-                words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
+                words, boxes = apply_tesseract(image, ocr_lang, tesseract_config, input_data_format=input_data_format)
                 words_batch.append(words)
                 boxes_batch.append(boxes)

         if do_resize:
-            images = [self.resize(image=image, size=size, resample=resample) for image in images]
+            images = [
+                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_rescale:
-            images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]

         if do_normalize:
-            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]

-        images = [to_channel_dimension_format(image, data_format) for image in images]
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]

         data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
...
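
A second pattern recurs through the LayoutLMv3 (and, below, LeViT) preprocess bodies: every step of the resize, rescale, normalize chain receives the one `input_data_format` resolved up front, and the final layout conversion passes `input_channel_dim` so the source layout is never re-inferred mid-pipeline. The conversion helper in isolation, with illustrative arrays not taken from the diff:

import numpy as np
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension

hwc = np.zeros((224, 224, 3))

# Naming the input layout via input_channel_dim skips a second round of shape guessing.
chw = to_channel_dimension_format(hwc, ChannelDimension.FIRST, input_channel_dim=ChannelDimension.LAST)
print(chw.shape)  # (3, 224, 224)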
@@ -30,6 +30,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -119,6 +120,7 @@ class LevitImageProcessor(BaseImageProcessor):
         size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -143,19 +145,28 @@ class LevitImageProcessor(BaseImageProcessor):
                 Resampling filter to use when resizing the image.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
         """
         size_dict = get_size_dict(size, default_to_square=False)
         # size_dict is a dict with either keys "height" and "width" or "shortest_edge"
         if "shortest_edge" in size:
             shortest_edge = int((256 / 224) * size["shortest_edge"])
-            output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
+            output_size = get_resize_output_image_size(
+                image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
+            )
             size_dict = {"height": output_size[0], "width": output_size[1]}
         if "height" not in size_dict or "width" not in size_dict:
             raise ValueError(
                 f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size_dict.keys()}"
             )
         return resize(
-            image, size=(size_dict["height"], size_dict["width"]), resample=resample, data_format=data_format, **kwargs
+            image,
+            size=(size_dict["height"], size_dict["width"]),
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
         )

     def preprocess(
@@ -173,6 +184,7 @@ class LevitImageProcessor(BaseImageProcessor):
         image_std: Optional[Union[float, Iterable[float]]] = None,
         return_tensors: Optional[TensorType] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> BatchFeature:
         """
@@ -217,6 +229,12 @@ class LevitImageProcessor(BaseImageProcessor):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         resample = resample if resample is not None else self.resample
@@ -255,19 +273,27 @@ class LevitImageProcessor(BaseImageProcessor):
         # All transformations expect numpy arrays.
         images = [to_numpy_array(image) for image in images]

+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
         if do_resize:
-            images = [self.resize(image, size, resample) for image in images]
+            images = [self.resize(image, size, resample, input_data_format=input_data_format) for image in images]

         if do_center_crop:
-            images = [self.center_crop(image, crop_size) for image in images]
+            images = [self.center_crop(image, crop_size, input_data_format=input_data_format) for image in images]

         if do_rescale:
-            images = [self.rescale(image, rescale_factor) for image in images]
+            images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

         if do_normalize:
-            images = [self.normalize(image, image_mean, image_std) for image in images]
+            images = [
+                self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
+            ]

-        images = [to_channel_dimension_format(image, data_format) for image in images]
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]

         data = {"pixel_values": images}
         return BatchFeature(data=data, tensor_type=return_tensors)
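
One LeViT-specific detail in the resize hunk above: when `size` carries `shortest_edge`, the target for the short side is `int((256 / 224) * size["shortest_edge"])`, mirroring the classic resize-to-256-then-center-crop-to-224 recipe. The arithmetic, worked through for the defaults (a sketch; the library helper also handles rounding and aspect-ratio edge cases):

# int((256 / 224) * 224) == 256: short side goes to 256, then the center crop takes 224.
shortest_edge = 224
scaled_short_side = int((256 / 224) * shortest_edge)
assert scaled_short_side == 256

# For a 480x640 (height x width) image the short side (480) maps to 256 and the
# long side scales proportionally: 640 * 256 / 480 -> 341 after truncation.
height, width = 480, 640
new_height = scaled_short_side
new_width = int(width * scaled_short_side / height)
print(new_height, new_width)  # 256 341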
@@ -66,23 +66,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:

 # Copied from transformers.models.detr.image_processing_detr.get_max_height_width
-def get_max_height_width(images: List[np.ndarray]) -> List[int]:
+def get_max_height_width(
+    images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> List[int]:
     """
     Get the maximum height and width across all images in a batch.
     """
-    input_channel_dimension = infer_channel_dimension_format(images[0])
+    if input_data_format is None:
+        input_data_format = infer_channel_dimension_format(images[0])

-    if input_channel_dimension == ChannelDimension.FIRST:
+    if input_data_format == ChannelDimension.FIRST:
         _, max_height, max_width = max_across_indices([img.shape for img in images])
-    elif input_channel_dimension == ChannelDimension.LAST:
+    elif input_data_format == ChannelDimension.LAST:
         max_height, max_width, _ = max_across_indices([img.shape for img in images])
     else:
-        raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
+        raise ValueError(f"Invalid channel dimension format: {input_data_format}")
     return (max_height, max_width)


 # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
-def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
+def make_pixel_mask(
+    image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
+) -> np.ndarray:
     """
     Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

@@ -92,7 +97,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
         output_size (`Tuple[int, int]`):
             Output size of the mask.
     """
-    input_height, input_width = get_image_size(image)
+    input_height, input_width = get_image_size(image, channel_dim=input_data_format)
     mask = np.zeros(output_size, dtype=np.int64)
     mask[:input_height, :input_width] = 1
     return mask
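
Together these two helpers define the padding contract used further down: pad every image in the batch to the shared maximum height and width, and emit a mask marking which pixels are real. A standalone NumPy sketch of the intended behaviour, with toy sizes, mirroring the functions above rather than importing them:

import numpy as np

# Two channels-first images of different spatial sizes.
images = [np.zeros((3, 2, 4)), np.zeros((3, 3, 2))]

# get_max_height_width over channels-first shapes -> (3, 4).
max_height = max(img.shape[1] for img in images)
max_width = max(img.shape[2] for img in images)
print((max_height, max_width))  # (3, 4)

# make_pixel_mask for the first image on the padded (3, 4) canvas:
mask = np.zeros((max_height, max_width), dtype=np.int64)
mask[:2, :4] = 1  # 1 marks valid pixels, 0 marks padding
print(mask)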
@@ -297,6 +302,7 @@ def get_mask2former_resize_output_image_size(
     max_size: Optional[int] = None,
     size_divisor: int = 0,
     default_to_square: bool = True,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
 ) -> tuple:
     """
     Computes the output size given the desired size.
@@ -310,14 +316,18 @@ def get_mask2former_resize_output_image_size(
             Whether to default to square if no size is provided.
         max_size (`int`, *optional*):
             The maximum size of the output image.
-        size_divisible (`int`, *optional*, defaults to `0`):
+        size_divisible (`int`, *optional*, defaults to 0):
             If size_divisible is given, the output image size will be divisible by the number.

     Returns:
         `Tuple[int, int]`: The output size.
     """
     output_size = get_resize_output_image_size(
-        input_image=image, size=size, default_to_square=default_to_square, max_size=max_size
+        input_image=image,
+        size=size,
+        default_to_square=default_to_square,
+        max_size=max_size,
+        input_data_format=input_data_format,
     )

     if size_divisor > 0:
@@ -450,11 +460,27 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         size_divisor: int = 0,
         resample: PILImageResampling = PILImageResampling.BILINEAR,
         data_format=None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> np.ndarray:
         """
         Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
         int, smaller edge of the image will be matched to this number.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                The size of the output image.
+            size_divisor (`int`, *optional*, defaults to 0):
+                If size_divisor is given, the output image size will be divisible by the number.
+            resample (`PILImageResampling` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
         """
         if "max_size" in kwargs:
             warnings.warn(
@@ -482,13 +508,20 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
             max_size=max_size,
             size_divisor=size_divisor,
             default_to_square=False,
+            input_data_format=input_data_format,
+        )
+        image = resize(
+            image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
         )
-        image = resize(image, size=size, resample=resample, data_format=data_format)
         return image

     # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
     def rescale(
-        self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
+        self,
+        image: np.ndarray,
+        rescale_factor: float,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> np.ndarray:
         """
         Rescale the image by the given factor. image = image * rescale_factor.
@@ -503,8 +536,13 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the input image. If unset, is inferred from the input image. Can be
+                one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
         """
-        return rescale(image, rescale_factor, data_format=data_format)
+        return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)

     # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
     def convert_segmentation_map_to_binary_masks(
@@ -538,13 +576,16 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         do_normalize: bool = None,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ):
         if do_resize:
-            image = self.resize(image, size=size, size_divisor=size_divisor, resample=resample)
+            image = self.resize(
+                image, size=size, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format
+            )
         if do_rescale:
-            image = self.rescale(image, rescale_factor=rescale_factor)
+            image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
         if do_normalize:
-            image = self.normalize(image, mean=image_mean, std=image_std)
+            image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
         return image

     def _preprocess_image(
@@ -560,10 +601,13 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> np.ndarray:
         """Preprocesses a single image."""
         # All transformations expect numpy arrays.
         image = to_numpy_array(image)
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(image)
         image = self._preprocess(
             image=image,
             do_resize=do_resize,
@@ -575,9 +619,10 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
             do_normalize=do_normalize,
             image_mean=image_mean,
             image_std=image_std,
+            input_data_format=input_data_format,
         )
         if data_format is not None:
-            image = to_channel_dimension_format(image, data_format)
+            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
         return image

     def _preprocess_mask(
@@ -586,14 +631,19 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         do_resize: bool = None,
         size: Dict[str, int] = None,
         size_divisor: int = 0,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> np.ndarray:
         """Preprocesses a single mask."""
         segmentation_map = to_numpy_array(segmentation_map)
         # Add channel dimension if missing - needed for certain transformations
-        added_channel_dim = False
         if segmentation_map.ndim == 2:
             added_channel_dim = True
             segmentation_map = segmentation_map[None, ...]
+            input_data_format = ChannelDimension.FIRST
+        else:
+            added_channel_dim = False
+            if input_data_format is None:
+                input_data_format = infer_channel_dimension_format(segmentation_map)
         # TODO: (Amy)
         # Rework segmentation map processing to include reducing labels and resizing which doesn't
         # drop segment IDs > 255.
@@ -605,6 +655,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
             size_divisor=size_divisor,
             do_rescale=False,
             do_normalize=False,
+            input_data_format=input_data_format,
         )
         # Remove extra channel dimension if added for processing
         if added_channel_dim:
@@ -629,6 +680,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         reduce_labels: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
     ) -> BatchFeature:
         if "pad_and_return_pixel_mask" in kwargs:
@@ -691,17 +743,26 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
                 image_mean=image_mean,
                 image_std=image_std,
                 data_format=data_format,
+                input_data_format=input_data_format,
             )
             for image in images
         ]

         if segmentation_maps is not None:
             segmentation_maps = [
-                self._preprocess_mask(segmentation_map, do_resize, size, size_divisor)
+                self._preprocess_mask(
+                    segmentation_map, do_resize, size, size_divisor, input_data_format=input_data_format
+                )
                 for segmentation_map in segmentation_maps
             ]
         encoded_inputs = self.encode_inputs(
-            images, segmentation_maps, instance_id_to_semantic_id, ignore_index, reduce_labels, return_tensors
+            images,
+            segmentation_maps,
+            instance_id_to_semantic_id,
+            ignore_index,
+            reduce_labels,
+            return_tensors,
+            input_data_format=input_data_format,
         )
         return encoded_inputs
@@ -712,18 +773,24 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         output_size: Tuple[int, int],
         constant_values: Union[float, Iterable[float]] = 0,
         data_format: Optional[ChannelDimension] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> np.ndarray:
         """
         Pad an image with zeros to the given size.
         """
-        input_height, input_width = get_image_size(image)
+        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
         output_height, output_width = output_size

         pad_bottom = output_height - input_height
         pad_right = output_width - input_width
         padding = ((0, pad_bottom), (0, pad_right))
         padded_image = pad(
-            image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
+            image,
+            padding,
+            mode=PaddingMode.CONSTANT,
+            constant_values=constant_values,
+            data_format=data_format,
+            input_data_format=input_data_format,
         )
         return padded_image
@@ -735,6 +802,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         return_pixel_mask: bool = True,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Optional[ChannelDimension] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> BatchFeature:
         """
         Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
@@ -756,17 +824,28 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        pad_size = get_max_height_width(images)
+        pad_size = get_max_height_width(images, input_data_format=input_data_format)

         padded_images = [
-            self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
+            self._pad_image(
+                image,
+                pad_size,
+                constant_values=constant_values,
+                data_format=data_format,
+                input_data_format=input_data_format,
+            )
             for image in images
         ]
         data = {"pixel_values": padded_images}

         if return_pixel_mask:
-            masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
+            masks = [
+                make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
+                for image in images
+            ]
             data["pixel_mask"] = masks

         return BatchFeature(data=data, tensor_type=return_tensors)
@@ -779,6 +858,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         ignore_index: Optional[int] = None,
         reduce_labels: bool = False,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ):
         """
         Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
@@ -815,6 +895,9 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
                 If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
                 objects.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+
         Returns:
             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
@@ -831,7 +914,13 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         reduce_labels = self.reduce_labels if reduce_labels is None else reduce_labels

         pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
-        encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors)
+
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(pixel_values_list[0])
+
+        encoded_inputs = self.pad(
+            pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format
+        )

         if segmentation_maps is not None:
             mask_labels = []
...
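
The MaskFormer file below mirrors the Mask2Former changes almost line for line. End to end, the Mask2Former entry points now thread a single `input_data_format` from `preprocess` through `encode_inputs`, `pad`, and the pixel-level helpers. A hedged usage sketch; `do_resize=False` and the image sizes are chosen only to make the padding visible, and the printed shapes are what the padding contract implies rather than output captured from a run:

import numpy as np
from transformers import Mask2FormerImageProcessor

processor = Mask2FormerImageProcessor(do_resize=False)

# Channels-last images of different sizes; state the layout rather than infer it.
images = [
    np.random.randint(0, 256, size=(64, 96, 3), dtype=np.uint8),
    np.random.randint(0, 256, size=(80, 72, 3), dtype=np.uint8),
]
batch = processor.preprocess(images, input_data_format="channels_last", return_tensors="np")
print(batch["pixel_values"].shape)  # padded to the batch max: (2, 3, 80, 96)
print(batch["pixel_mask"].shape)    # (2, 80, 96); 1 = valid pixel, 0 = padding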
...@@ -70,23 +70,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]: ...@@ -70,23 +70,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width # Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> List[int]: def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
""" """
Get the maximum height and width across all images in a batch. Get the maximum height and width across all images in a batch.
""" """
input_channel_dimension = infer_channel_dimension_format(images[0]) if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_channel_dimension == ChannelDimension.FIRST: if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images]) _, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST: elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images]) max_height, max_width, _ = max_across_indices([img.shape for img in images])
else: else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}") raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width) return (max_height, max_width)
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
""" """
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
...@@ -96,7 +101,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr ...@@ -96,7 +101,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
output_size (`Tuple[int, int]`): output_size (`Tuple[int, int]`):
Output size of the mask. Output size of the mask.
""" """
input_height, input_width = get_image_size(image) input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64) mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1 mask[:input_height, :input_width] = 1
return mask return mask
...@@ -299,6 +304,7 @@ def get_maskformer_resize_output_image_size( ...@@ -299,6 +304,7 @@ def get_maskformer_resize_output_image_size(
max_size: Optional[int] = None, max_size: Optional[int] = None,
size_divisor: int = 0, size_divisor: int = 0,
default_to_square: bool = True, default_to_square: bool = True,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple: ) -> tuple:
""" """
Computes the output size given the desired size. Computes the output size given the desired size.
...@@ -312,14 +318,18 @@ def get_maskformer_resize_output_image_size( ...@@ -312,14 +318,18 @@ def get_maskformer_resize_output_image_size(
Whether to default to square if no size is provided. Whether to default to square if no size is provided.
max_size (`int`, *optional*): max_size (`int`, *optional*):
The maximum size of the output image. The maximum size of the output image.
size_divisible (`int`, *optional*, defaults to `0`): size_divisible (`int`, *optional*, defaults to 0):
If size_divisible is given, the output image size will be divisible by the number. If size_divisible is given, the output image size will be divisible by the number.
Returns: Returns:
`Tuple[int, int]`: The output size. `Tuple[int, int]`: The output size.
""" """
output_size = get_resize_output_image_size( output_size = get_resize_output_image_size(
input_image=image, size=size, default_to_square=default_to_square, max_size=max_size input_image=image,
size=size,
default_to_square=default_to_square,
max_size=max_size,
input_data_format=input_data_format,
) )
if size_divisor > 0: if size_divisor > 0:
...@@ -458,11 +468,27 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -458,11 +468,27 @@ class MaskFormerImageProcessor(BaseImageProcessor):
size_divisor: int = 0, size_divisor: int = 0,
resample: PILImageResampling = PILImageResampling.BILINEAR, resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format=None, data_format=None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
int, smaller edge of the image will be matched to this number. int, smaller edge of the image will be matched to this number.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
The size of the output image.
size_divisor (`int`, *optional*, defaults to 0):
If size_divisor is given, the output image size will be divisible by the number.
resample (`PILImageResampling` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
if "max_size" in kwargs: if "max_size" in kwargs:
warnings.warn( warnings.warn(
...@@ -490,13 +516,20 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -490,13 +516,20 @@ class MaskFormerImageProcessor(BaseImageProcessor):
max_size=max_size, max_size=max_size,
size_divisor=size_divisor, size_divisor=size_divisor,
default_to_square=False, default_to_square=False,
input_data_format=input_data_format,
)
image = resize(
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
) )
image = resize(image, size=size, resample=resample, data_format=data_format)
return image return image
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
def rescale( def rescale(
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None self,
image: np.ndarray,
rescale_factor: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Rescale the image by the given factor. image = image * rescale_factor. Rescale the image by the given factor. image = image * rescale_factor.
...@@ -511,8 +544,13 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -511,8 +544,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
""" """
return rescale(image, rescale_factor, data_format=data_format) return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
def convert_segmentation_map_to_binary_masks( def convert_segmentation_map_to_binary_masks(
self, self,
...@@ -545,13 +583,16 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -545,13 +583,16 @@ class MaskFormerImageProcessor(BaseImageProcessor):
do_normalize: bool = None, do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
): ):
if do_resize: if do_resize:
image = self.resize(image, size=size, size_divisor=size_divisor, resample=resample) image = self.resize(
image, size=size, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format
)
if do_rescale: if do_rescale:
image = self.rescale(image, rescale_factor=rescale_factor) image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
if do_normalize: if do_normalize:
image = self.normalize(image, mean=image_mean, std=image_std) image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
return image return image
def _preprocess_image( def _preprocess_image(
...@@ -567,10 +608,13 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -567,10 +608,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
image = to_numpy_array(image) image = to_numpy_array(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess( image = self._preprocess(
image=image, image=image,
do_resize=do_resize, do_resize=do_resize,
...@@ -582,9 +626,10 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -582,9 +626,10 @@ class MaskFormerImageProcessor(BaseImageProcessor):
do_normalize=do_normalize, do_normalize=do_normalize,
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
input_data_format=input_data_format,
) )
if data_format is not None: if data_format is not None:
image = to_channel_dimension_format(image, data_format) image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image return image
def _preprocess_mask( def _preprocess_mask(
...@@ -593,14 +638,19 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -593,14 +638,19 @@ class MaskFormerImageProcessor(BaseImageProcessor):
do_resize: bool = None, do_resize: bool = None,
size: Dict[str, int] = None, size: Dict[str, int] = None,
size_divisor: int = 0, size_divisor: int = 0,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single mask.""" """Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map) segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations # Add channel dimension if missing - needed for certain transformations
added_channel_dim = False
if segmentation_map.ndim == 2: if segmentation_map.ndim == 2:
added_channel_dim = True added_channel_dim = True
segmentation_map = segmentation_map[None, ...] segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
# TODO: (Amy) # TODO: (Amy)
# Remork segmentation map processing to include reducing labels and resizing which doesn't # Remork segmentation map processing to include reducing labels and resizing which doesn't
# drop segment IDs > 255. # drop segment IDs > 255.
...@@ -612,6 +662,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -612,6 +662,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
size_divisor=size_divisor, size_divisor=size_divisor,
do_rescale=False, do_rescale=False,
do_normalize=False, do_normalize=False,
input_data_format=input_data_format,
) )
# Remove extra channel dimension if added for processing # Remove extra channel dimension if added for processing
if added_channel_dim: if added_channel_dim:
...@@ -636,6 +687,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -636,6 +687,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
do_reduce_labels: Optional[bool] = None, do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> BatchFeature: ) -> BatchFeature:
if "pad_and_return_pixel_mask" in kwargs: if "pad_and_return_pixel_mask" in kwargs:
...@@ -708,17 +760,26 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -708,17 +760,26 @@ class MaskFormerImageProcessor(BaseImageProcessor):
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
data_format=data_format, data_format=data_format,
input_data_format=input_data_format,
) )
for image in images for image in images
] ]
if segmentation_maps is not None: if segmentation_maps is not None:
segmentation_maps = [ segmentation_maps = [
self._preprocess_mask(segmentation_map, do_resize, size, size_divisor) self._preprocess_mask(
segmentation_map, do_resize, size, size_divisor, input_data_format=input_data_format
)
for segmentation_map in segmentation_maps for segmentation_map in segmentation_maps
] ]
encoded_inputs = self.encode_inputs( encoded_inputs = self.encode_inputs(
images, segmentation_maps, instance_id_to_semantic_id, ignore_index, do_reduce_labels, return_tensors images,
segmentation_maps,
instance_id_to_semantic_id,
ignore_index,
do_reduce_labels,
return_tensors,
input_data_format=input_data_format,
) )
return encoded_inputs return encoded_inputs
...@@ -729,18 +790,24 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -729,18 +790,24 @@ class MaskFormerImageProcessor(BaseImageProcessor):
output_size: Tuple[int, int], output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0, constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Pad an image with zeros to the given size. Pad an image with zeros to the given size.
""" """
input_height, input_width = get_image_size(image) input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size output_height, output_width = output_size
pad_bottom = output_height - input_height pad_bottom = output_height - input_height
pad_right = output_width - input_width pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right)) padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad( padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
) )
return padded_image return padded_image
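A numpy-only sketch of the bottom/right padding computed above, assuming a channels_last image (the library `pad` helper additionally handles channels_first via `input_data_format`):

import numpy as np

image = np.ones((2, 3, 3))  # (height, width, channels)
output_height, output_width = 4, 5
pad_bottom = output_height - image.shape[0]
pad_right = output_width - image.shape[1]
# Zeros are added only below and to the right, matching ((0, pad_bottom), (0, pad_right)).
padded = np.pad(image, ((0, pad_bottom), (0, pad_right), (0, 0)), mode="constant")
assert padded.shape == (4, 5, 3)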
...@@ -752,6 +819,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -752,6 +819,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
return_pixel_mask: bool = True, return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature: ) -> BatchFeature:
""" """
Pads a batch of images to the bottom and right of the image with zeros to the size of the largest height and width Pads a batch of images to the bottom and right of the image with zeros to the size of the largest height and width
...@@ -773,17 +841,28 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -773,17 +841,28 @@ class MaskFormerImageProcessor(BaseImageProcessor):
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
pad_size = get_max_height_width(images) pad_size = get_max_height_width(images, input_data_format=input_data_format)
padded_images = [ padded_images = [
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format) self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images for image in images
] ]
data = {"pixel_values": padded_images} data = {"pixel_values": padded_images}
if return_pixel_mask: if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images] masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks data["pixel_mask"] = masks
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
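A hedged usage sketch of the padding path with the new argument; the printed shapes assume the default `data_format` passthrough:

import numpy as np
from transformers import MaskFormerImageProcessor
from transformers.image_utils import ChannelDimension

processor = MaskFormerImageProcessor()
images = [np.zeros((3, 480, 640)), np.zeros((3, 512, 600))]
batch = processor.pad(images, return_tensors="np", input_data_format=ChannelDimension.FIRST)
print(batch["pixel_values"].shape)  # (2, 3, 512, 640): batch-wide max height/width
print(batch["pixel_mask"].shape)    # (2, 512, 640): 1 marks valid pixels, 0 padding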
...@@ -796,6 +875,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -796,6 +875,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
ignore_index: Optional[int] = None, ignore_index: Optional[int] = None,
reduce_labels: bool = False, reduce_labels: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
): ):
""" """
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`. Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
...@@ -848,12 +928,18 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -848,12 +928,18 @@ class MaskFormerImageProcessor(BaseImageProcessor):
reduce_labels = self.do_reduce_labels if reduce_labels is None else reduce_labels reduce_labels = self.do_reduce_labels if reduce_labels is None else reduce_labels
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list] pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(pixel_values_list[0])
encoded_inputs = self.pad(
pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format
)
if segmentation_maps is not None: if segmentation_maps is not None:
mask_labels = [] mask_labels = []
class_labels = [] class_labels = []
pad_size = get_max_height_width(pixel_values_list) pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
# Convert to list of binary masks and labels # Convert to list of binary masks and labels
for idx, segmentation_map in enumerate(segmentation_maps): for idx, segmentation_map in enumerate(segmentation_maps):
segmentation_map = to_numpy_array(segmentation_map) segmentation_map = to_numpy_array(segmentation_map)
...@@ -869,7 +955,13 @@ class MaskFormerImageProcessor(BaseImageProcessor): ...@@ -869,7 +955,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
# this will be removed in the future # this will be removed in the future
masks = [mask[None, ...] for mask in masks] masks = [mask[None, ...] for mask in masks]
masks = [ masks = [
self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks self._pad_image(
image=mask,
output_size=pad_size,
constant_values=ignore_index,
input_data_format=ChannelDimension.FIRST,
)
for mask in masks
] ]
masks = np.concatenate(masks, axis=0) masks = np.concatenate(masks, axis=0)
mask_labels.append(torch.from_numpy(masks)) mask_labels.append(torch.from_numpy(masks))
......
...@@ -30,6 +30,7 @@ from ...image_utils import ( ...@@ -30,6 +30,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -118,6 +119,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor): ...@@ -118,6 +119,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC, resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -133,12 +135,23 @@ class MobileNetV1ImageProcessor(BaseImageProcessor): ...@@ -133,12 +135,23 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image. Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size: if "shortest_edge" not in size:
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False) output_size = get_resize_output_image_size(
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess( def preprocess(
self, self,
...@@ -155,6 +168,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor): ...@@ -155,6 +168,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -197,6 +211,12 @@ class MobileNetV1ImageProcessor(BaseImageProcessor): ...@@ -197,6 +211,12 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image. - Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
...@@ -234,19 +254,36 @@ class MobileNetV1ImageProcessor(BaseImageProcessor): ...@@ -234,19 +254,36 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize: if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images] images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_center_crop: if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images] images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize: if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
images = [to_channel_dimension_format(image, data_format) for image in images] images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images} data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
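A hedged end-to-end sketch for this processor; with the new argument a channels_first array is taken at face value instead of being inferred:

import numpy as np
from transformers import MobileNetV1ImageProcessor

processor = MobileNetV1ImageProcessor()
image = np.random.randint(0, 256, (3, 500, 400), dtype=np.uint8)  # channels_first
batch = processor(image, input_data_format="channels_first", return_tensors="np")
print(batch["pixel_values"].shape)  # (1, 3, 224, 224) with the default crop_size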
...@@ -30,6 +30,7 @@ from ...image_utils import ( ...@@ -30,6 +30,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -122,6 +123,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): ...@@ -122,6 +123,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC, resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -137,12 +139,23 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): ...@@ -137,12 +139,23 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image. Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size: if "shortest_edge" not in size:
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False) output_size = get_resize_output_image_size(
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess( def preprocess(
self, self,
...@@ -159,6 +172,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): ...@@ -159,6 +172,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -201,6 +215,12 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): ...@@ -201,6 +215,12 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image. - Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
...@@ -238,19 +258,36 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): ...@@ -238,19 +258,36 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize: if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images] images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_center_crop: if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images] images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize: if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
images = [to_channel_dimension_format(image, data_format) for image in images] images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images} data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
......
...@@ -29,6 +29,7 @@ from ...image_utils import ( ...@@ -29,6 +29,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -114,6 +115,7 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -114,6 +115,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR, resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -129,15 +131,29 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -129,15 +131,29 @@ class MobileViTImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image. Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size: if "shortest_edge" not in size:
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False) output_size = get_resize_output_image_size(
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def flip_channel_order( def flip_channel_order(
self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None self,
image: np.ndarray,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Flip the color channels from RGB to BGR or vice versa. Flip the color channels from RGB to BGR or vice versa.
...@@ -147,8 +163,10 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -147,8 +163,10 @@ class MobileViTImageProcessor(BaseImageProcessor):
The image, represented as a numpy array. The image, represented as a numpy array.
data_format (`ChannelDimension` or `str`, *optional*): data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
return flip_channel_order(image, data_format=data_format) return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)
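A numpy-only sketch of the channel flip for a channels_last array (the helper also accepts channels_first input via `input_data_format`):

import numpy as np

rgb = np.arange(12).reshape(2, 2, 3)
bgr = rgb[..., ::-1]  # reverse the channel axis: RGB -> BGR
assert (bgr[..., 0] == rgb[..., 2]).all()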
def preprocess( def preprocess(
self, self,
...@@ -163,6 +181,7 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -163,6 +181,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
do_flip_channel_order: bool = None, do_flip_channel_order: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
""" """
...@@ -199,6 +218,12 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -199,6 +218,12 @@ class MobileViTImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of: The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample resample = resample if resample is not None else self.resample
...@@ -234,20 +259,34 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -234,20 +259,34 @@ class MobileViTImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize: if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images] images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_center_crop: if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images] images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
# the pretrained checkpoints assume images are BGR, not RGB # the pretrained checkpoints assume images are BGR, not RGB
if do_flip_channel_order: if do_flip_channel_order:
images = [self.flip_channel_order(image=image) for image in images] images = [self.flip_channel_order(image=image, input_data_format=input_data_format) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images] images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images} data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
......
...@@ -67,23 +67,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]: ...@@ -67,23 +67,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width # Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> List[int]: def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
""" """
Get the maximum height and width across all images in a batch. Get the maximum height and width across all images in a batch.
""" """
input_channel_dimension = infer_channel_dimension_format(images[0]) if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_channel_dimension == ChannelDimension.FIRST: if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images]) _, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST: elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images]) max_height, max_width, _ = max_across_indices([img.shape for img in images])
else: else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}") raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width) return (max_height, max_width)
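A numpy-only sketch of the max-size computation for channels_first inputs, with illustrative shapes:

import numpy as np

images = [np.zeros((3, 480, 640)), np.zeros((3, 512, 600))]
# max_across_indices takes the per-index maximum over the shape tuples:
_, max_height, max_width = [max(dims) for dims in zip(*(img.shape for img in images))]
assert (max_height, max_width) == (512, 640)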
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
""" """
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
...@@ -93,7 +98,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr ...@@ -93,7 +98,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
output_size (`Tuple[int, int]`): output_size (`Tuple[int, int]`):
Output size of the mask. Output size of the mask.
""" """
input_height, input_width = get_image_size(image) input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64) mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1 mask[:input_height, :input_width] = 1
return mask return mask
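A short sketch of the resulting mask, with illustrative sizes:

import numpy as np

input_height, input_width = 2, 3
mask = np.zeros((4, 5), dtype=np.int64)
mask[:input_height, :input_width] = 1
# mask[0] -> [1, 1, 1, 0, 0]; the top-left 2x3 block marks the valid pixels.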
...@@ -295,6 +300,7 @@ def get_oneformer_resize_output_image_size( ...@@ -295,6 +300,7 @@ def get_oneformer_resize_output_image_size(
size: Union[int, Tuple[int, int], List[int], Tuple[int]], size: Union[int, Tuple[int, int], List[int], Tuple[int]],
max_size: Optional[int] = None, max_size: Optional[int] = None,
default_to_square: bool = True, default_to_square: bool = True,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple: ) -> tuple:
""" """
Computes the output size given the desired size. Computes the output size given the desired size.
...@@ -304,16 +310,20 @@ def get_oneformer_resize_output_image_size( ...@@ -304,16 +310,20 @@ def get_oneformer_resize_output_image_size(
The input image. The input image.
size (`int`, `Tuple[int, int]`, `List[int]`, `Tuple[int]`): size (`int`, `Tuple[int, int]`, `List[int]`, `Tuple[int]`):
The size of the output image. The size of the output image.
default_to_square (`bool`, *optional*, defaults to `True`):
Whether to default to square if no size is provided.
max_size (`int`, *optional*): max_size (`int`, *optional*):
The maximum size of the output image. The maximum size of the output image.
default_to_square (`bool`, *optional*, defaults to `True`):
Whether to default to square if no size is provided.
Returns: Returns:
`Tuple[int, int]`: The output size. `Tuple[int, int]`: The output size.
""" """
output_size = get_resize_output_image_size( output_size = get_resize_output_image_size(
input_image=image, size=size, default_to_square=default_to_square, max_size=max_size input_image=image,
size=size,
default_to_square=default_to_square,
max_size=max_size,
input_data_format=input_data_format,
) )
return output_size return output_size
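A sketch of the underlying shortest-edge arithmetic, assuming an integer `size`; exact rounding is up to `get_resize_output_image_size`:

height, width, size = 480, 640, 800
short, long = min(height, width), max(height, width)
new_short, new_long = size, int(size * long / short)
# -> (800, 1066); the helper maps these back to (height, width) order and,
# when max_size is given, scales down so the long side does not exceed it.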
...@@ -442,6 +452,7 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -442,6 +452,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR, resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format=None, data_format=None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -469,17 +480,20 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -469,17 +480,20 @@ class OneFormerImageProcessor(BaseImageProcessor):
f" {size.keys()}." f" {size.keys()}."
) )
size = get_oneformer_resize_output_image_size( size = get_oneformer_resize_output_image_size(
image=image, image=image, size=size, max_size=max_size, default_to_square=False, input_data_format=input_data_format
size=size, )
max_size=max_size, image = resize(
default_to_square=False, image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format
) )
image = resize(image, size=size, resample=resample, data_format=data_format)
return image return image
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
def rescale( def rescale(
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None self,
image: np.ndarray,
rescale_factor: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Rescale the image by the given factor. image = image * rescale_factor. Rescale the image by the given factor. image = image * rescale_factor.
...@@ -494,8 +508,13 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -494,8 +508,13 @@ class OneFormerImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, it is inferred from the input image. Can be
one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
""" """
return rescale(image, rescale_factor, data_format=data_format) return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
def convert_segmentation_map_to_binary_masks( def convert_segmentation_map_to_binary_masks(
...@@ -528,13 +547,14 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -528,13 +547,14 @@ class OneFormerImageProcessor(BaseImageProcessor):
do_normalize: bool = None, do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
): ):
if do_resize: if do_resize:
image = self.resize(image, size=size, resample=resample) image = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
if do_rescale: if do_rescale:
image = self.rescale(image, rescale_factor=rescale_factor) image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
if do_normalize: if do_normalize:
image = self.normalize(image, mean=image_mean, std=image_std) image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
return image return image
def _preprocess_image( def _preprocess_image(
...@@ -549,10 +569,13 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -549,10 +569,13 @@ class OneFormerImageProcessor(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
image = to_numpy_array(image) image = to_numpy_array(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess( image = self._preprocess(
image=image, image=image,
do_resize=do_resize, do_resize=do_resize,
...@@ -563,9 +586,10 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -563,9 +586,10 @@ class OneFormerImageProcessor(BaseImageProcessor):
do_normalize=do_normalize, do_normalize=do_normalize,
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
input_data_format=input_data_format,
) )
if data_format is not None: if data_format is not None:
image = to_channel_dimension_format(image, data_format) image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image return image
def _preprocess_mask( def _preprocess_mask(
...@@ -573,14 +597,19 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -573,14 +597,19 @@ class OneFormerImageProcessor(BaseImageProcessor):
segmentation_map: ImageInput, segmentation_map: ImageInput,
do_resize: bool = None, do_resize: bool = None,
size: Dict[str, int] = None, size: Dict[str, int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single mask.""" """Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map) segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations # Add channel dimension if missing - needed for certain transformations
added_channel_dim = False
if segmentation_map.ndim == 2: if segmentation_map.ndim == 2:
added_channel_dim = True added_channel_dim = True
segmentation_map = segmentation_map[None, ...] segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
# TODO: (Amy) # TODO: (Amy)
# Rework segmentation map processing to include reducing labels and resizing which doesn't # Rework segmentation map processing to include reducing labels and resizing which doesn't
# drop segment IDs > 255. # drop segment IDs > 255.
...@@ -591,6 +620,7 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -591,6 +620,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
size=size, size=size,
do_rescale=False, do_rescale=False,
do_normalize=False, do_normalize=False,
input_data_format=input_data_format,
) )
# Remove extra channel dimension if added for processing # Remove extra channel dimension if added for processing
if added_channel_dim: if added_channel_dim:
...@@ -615,6 +645,7 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -615,6 +645,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
do_reduce_labels: Optional[bool] = None, do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> BatchFeature: ) -> BatchFeature:
if "pad_and_return_pixel_mask" in kwargs: if "pad_and_return_pixel_mask" in kwargs:
...@@ -691,13 +722,15 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -691,13 +722,15 @@ class OneFormerImageProcessor(BaseImageProcessor):
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
data_format=data_format, data_format=data_format,
input_data_format=input_data_format,
) )
for image in images for image in images
] ]
if segmentation_maps is not None: if segmentation_maps is not None:
segmentation_maps = [ segmentation_maps = [
self._preprocess_mask(segmentation_map, do_resize, size) for segmentation_map in segmentation_maps self._preprocess_mask(segmentation_map, do_resize, size, input_data_format=input_data_format)
for segmentation_map in segmentation_maps
] ]
encoded_inputs = self.encode_inputs( encoded_inputs = self.encode_inputs(
images, images,
...@@ -707,6 +740,7 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -707,6 +740,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
ignore_index, ignore_index,
do_reduce_labels, do_reduce_labels,
return_tensors, return_tensors,
input_data_format=input_data_format,
) )
return encoded_inputs return encoded_inputs
...@@ -717,18 +751,24 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -717,18 +751,24 @@ class OneFormerImageProcessor(BaseImageProcessor):
output_size: Tuple[int, int], output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0, constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Pad an image with zeros to the given size. Pad an image with zeros to the given size.
""" """
input_height, input_width = get_image_size(image) input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size output_height, output_width = output_size
pad_bottom = output_height - input_height pad_bottom = output_height - input_height
pad_right = output_width - input_width pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right)) padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad( padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
) )
return padded_image return padded_image
...@@ -740,6 +780,7 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -740,6 +780,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
return_pixel_mask: bool = True, return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature: ) -> BatchFeature:
""" """
Pads a batch of images to the bottom and right of the image with zeros to the size of the largest height and width Pads a batch of images to the bottom and right of the image with zeros to the size of the largest height and width
...@@ -761,17 +802,28 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -761,17 +802,28 @@ class OneFormerImageProcessor(BaseImageProcessor):
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
pad_size = get_max_height_width(images) pad_size = get_max_height_width(images, input_data_format=input_data_format)
padded_images = [ padded_images = [
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format) self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images for image in images
] ]
data = {"pixel_values": padded_images} data = {"pixel_values": padded_images}
if return_pixel_mask: if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images] masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks data["pixel_mask"] = masks
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
...@@ -882,6 +934,7 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -882,6 +934,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
ignore_index: Optional[int] = None, ignore_index: Optional[int] = None,
reduce_labels: bool = False, reduce_labels: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
): ):
""" """
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`. Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
...@@ -921,6 +974,10 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -921,6 +974,10 @@ class OneFormerImageProcessor(BaseImageProcessor):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects. objects.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input
image.
Returns: Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields: [`BatchFeature`]: A [`BatchFeature`] with the following fields:
...@@ -938,8 +995,14 @@ class OneFormerImageProcessor(BaseImageProcessor): ...@@ -938,8 +995,14 @@ class OneFormerImageProcessor(BaseImageProcessor):
ignore_index = self.ignore_index if ignore_index is None else ignore_index ignore_index = self.ignore_index if ignore_index is None else ignore_index
reduce_labels = self.do_reduce_labels if reduce_labels is None else reduce_labels reduce_labels = self.do_reduce_labels if reduce_labels is None else reduce_labels
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list] pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
pad_size = get_max_height_width(pixel_values_list)
encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors) if input_data_format is None:
input_data_format = infer_channel_dimension_format(pixel_values_list[0])
pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
encoded_inputs = self.pad(
pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format
)
annotations = None annotations = None
if segmentation_maps is not None: if segmentation_maps is not None:
......
...@@ -33,6 +33,7 @@ from ...image_utils import ( ...@@ -33,6 +33,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -169,36 +170,79 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -169,36 +170,79 @@ class OwlViTImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC, resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
Resize an image to a certain size. Resize an image to a certain size.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
The size to resize the image to. Must contain height and width keys.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
The resampling filter to use when resizing the input.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=True) size = get_size_dict(size, default_to_square=True)
if "height" not in size or "width" not in size: if "height" not in size or "width" not in size:
raise ValueError("size dictionary must contain height and width keys") raise ValueError("size dictionary must contain height and width keys")
return resize(image, (size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs) return resize(
image,
(size["height"], size["width"]),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def center_crop( def center_crop(
self, self,
image: np.ndarray, image: np.ndarray,
crop_size: Dict[str, int], crop_size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
Center crop an image to a certain size. Center crop an image to a certain size.
Args:
image (`np.ndarray`):
Image to center crop.
crop_size (`Dict[str, int]`):
The size to center crop the image to. Must contain height and width keys.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
crop_size = get_size_dict(crop_size, default_to_square=True) crop_size = get_size_dict(crop_size, default_to_square=True)
if "height" not in crop_size or "width" not in crop_size: if "height" not in crop_size or "width" not in crop_size:
raise ValueError("crop_size dictionary must contain height and width keys") raise ValueError("crop_size dictionary must contain height and width keys")
return center_crop(image, (crop_size["height"], crop_size["width"]), data_format=data_format, **kwargs) return center_crop(
image,
(crop_size["height"], crop_size["width"]),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
def rescale( def rescale(
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None self,
image: np.ndarray,
rescale_factor: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Rescale the image by the given factor. image = image * rescale_factor. Rescale the image by the given factor. image = image * rescale_factor.
...@@ -213,8 +257,13 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -213,8 +257,13 @@ class OwlViTImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, it is inferred from the input image. Can be
one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
""" """
return rescale(image, rescale_factor, data_format=data_format) return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
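A numpy-only sketch of the rescale operation; the usual factor 1/255 maps uint8 pixels into [0, 1]:

import numpy as np

image = np.array([[0, 128, 255]], dtype=np.uint8)
rescaled = image * (1 / 255)
# -> [[0.0, 0.50196..., 1.0]], dtype float64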
def preprocess( def preprocess(
self, self,
...@@ -231,6 +280,7 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -231,6 +280,7 @@ class OwlViTImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[TensorType, str]] = None, return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> BatchFeature: ) -> BatchFeature:
""" """
...@@ -277,6 +327,12 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -277,6 +327,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: defaults to the channel dimension format of the input image. - Unset: defaults to the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
...@@ -312,19 +368,36 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -312,19 +368,36 @@ class OwlViTImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays # All transformations expect numpy arrays
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize: if do_resize:
images = [self.resize(image, size=size, resample=resample) for image in images] images = [
self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_center_crop: if do_center_crop:
images = [self.center_crop(image, crop_size=crop_size) for image in images] images = [
self.center_crop(image, crop_size=crop_size, input_data_format=input_data_format) for image in images
]
if do_rescale: if do_rescale:
images = [self.rescale(image, rescale_factor=rescale_factor) for image in images] images = [
self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize: if do_normalize:
images = [self.normalize(image, mean=image_mean, std=image_std) for image in images] images = [
self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
images = [to_channel_dimension_format(image, data_format) for image in images] for image in images
]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_inputs return encoded_inputs
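
A minimal usage sketch of the new argument, assuming a transformers install that includes this change; the random image and its size are illustrative only:

import numpy as np
from transformers import OwlViTImageProcessor
from transformers.image_utils import ChannelDimension

processor = OwlViTImageProcessor()
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # channels-last array

# Without input_data_format, the format is inferred once from images[0].
auto = processor.preprocess([image], return_tensors="np")

# With input_data_format, inference is skipped and the declared layout is trusted.
explicit = processor.preprocess([image], input_data_format=ChannelDimension.LAST, return_tensors="np")

assert auto["pixel_values"].shape == explicit["pixel_values"].shape  # (1, num_channels, H', W')
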
......
...@@ -27,6 +27,7 @@ from ...image_utils import ( ...@@ -27,6 +27,7 @@ from ...image_utils import (
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
get_image_size, get_image_size,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -117,6 +118,7 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -117,6 +118,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
crop_size: Dict[str, int], crop_size: Dict[str, int],
size: Optional[int] = None, size: Optional[int] = None,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -135,16 +137,24 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -135,16 +137,24 @@ class PerceiverImageProcessor(BaseImageProcessor):
Size of the image after resizing. If not provided, the self.size attribute will be used. Size of the image after resizing. If not provided, the self.size attribute will be used.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = self.size if size is None else size size = self.size if size is None else size
size = get_size_dict(size) size = get_size_dict(size)
crop_size = get_size_dict(crop_size, param_name="crop_size") crop_size = get_size_dict(crop_size, param_name="crop_size")
height, width = get_image_size(image) height, width = get_image_size(image, channel_dim=input_data_format)
min_dim = min(height, width) min_dim = min(height, width)
cropped_height = (size["height"] / crop_size["height"]) * min_dim cropped_height = (size["height"] / crop_size["height"]) * min_dim
cropped_width = (size["width"] / crop_size["width"]) * min_dim cropped_width = (size["width"] / crop_size["width"]) * min_dim
return center_crop(image, size=(cropped_height, cropped_width), data_format=data_format, **kwargs) return center_crop(
image,
size=(cropped_height, cropped_width),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
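
A worked example of the crop arithmetic above, assuming the 224/256 size and crop_size defaults this processor ships with (the input dimensions are illustrative):

size = {"height": 224, "width": 224}
crop_size = {"height": 256, "width": 256}
height, width = 480, 640                 # example input image
min_dim = min(height, width)             # 480

# Crop a (size / crop_size) fraction of the short side, so the later resize
# to `size` reproduces a 256 -> 224 center crop at the image's own scale.
cropped_height = (size["height"] / crop_size["height"]) * min_dim  # 420.0
cropped_width = (size["width"] / crop_size["width"]) * min_dim     # 420.0
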
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
def resize( def resize(
...@@ -153,6 +163,7 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -153,6 +163,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC, resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -170,6 +181,13 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -170,6 +181,13 @@ class PerceiverImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns: Returns:
`np.ndarray`: The resized image. `np.ndarray`: The resized image.
...@@ -178,7 +196,14 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -178,7 +196,14 @@ class PerceiverImageProcessor(BaseImageProcessor):
if "height" not in size or "width" not in size: if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"]) output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess( def preprocess(
self, self,
...@@ -195,6 +220,7 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -195,6 +220,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
""" """
...@@ -235,6 +261,12 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -235,6 +261,12 @@ class PerceiverImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of: The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
...@@ -272,19 +304,36 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -272,19 +304,36 @@ class PerceiverImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_center_crop: if do_center_crop:
images = [self.center_crop(image, crop_size, size=size) for image in images] images = [
self.center_crop(image, crop_size, size=size, input_data_format=input_data_format) for image in images
]
if do_resize: if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images] images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize: if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
images = [to_channel_dimension_format(image, data_format) for image in images] for image in images
]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images} data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
...@@ -157,7 +157,9 @@ def render_text( ...@@ -157,7 +157,9 @@ def render_text(
# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87 # Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
def render_header(image: np.ndarray, header: str, **kwargs): def render_header(
image: np.ndarray, header: str, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
):
""" """
Renders the input text as a header on the input image. Renders the input text as a header on the input image.
...@@ -176,7 +178,7 @@ def render_header(image: np.ndarray, header: str, **kwargs): ...@@ -176,7 +178,7 @@ def render_header(image: np.ndarray, header: str, **kwargs):
requires_backends(render_header, "vision") requires_backends(render_header, "vision")
# Convert to PIL image if necessary # Convert to PIL image if necessary
image = to_pil_image(image) image = to_pil_image(image, input_data_format=input_data_format)
header_image = render_text(header, **kwargs) header_image = render_text(header, **kwargs)
new_width = max(header_image.width, image.width) new_width = max(header_image.width, image.width)
...@@ -236,7 +238,14 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -236,7 +238,14 @@ class Pix2StructImageProcessor(BaseImageProcessor):
self.max_patches = max_patches self.max_patches = max_patches
self.is_vqa = is_vqa self.is_vqa = is_vqa
def extract_flattened_patches(self, image: np.ndarray, max_patches: int, patch_size: dict, **kwargs) -> np.ndarray: def extract_flattened_patches(
self,
image: np.ndarray,
max_patches: int,
patch_size: dict,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
""" """
Extract flattened patches from an image. Extract flattened patches from an image.
...@@ -256,11 +265,11 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -256,11 +265,11 @@ class Pix2StructImageProcessor(BaseImageProcessor):
_check_torch_version() _check_torch_version()
# convert to torch # convert to torch
image = to_channel_dimension_format(image, ChannelDimension.FIRST) image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
image = torch.from_numpy(image) image = torch.from_numpy(image)
patch_height, patch_width = patch_size["height"], patch_size["width"] patch_height, patch_width = patch_size["height"], patch_size["width"]
image_height, image_width = get_image_size(image) image_height, image_width = get_image_size(image, ChannelDimension.FIRST)
# maximize scale s.t. # maximize scale s.t.
scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width)) scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
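
Concretely, with illustrative numbers (the integer row/column flooring happens in the elided lines, so treat this as a sketch of the intent):

import math

max_patches, patch_height, patch_width = 2048, 16, 16
image_height, image_width = 512, 768

# Choose scale so that (scale*H/ph) * (scale*W/pw) == max_patches before flooring.
scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
# scale ~= 1.1547; the rescaled image tiles into 36 * 55 = 1980 <= 2048 patches.
rows = math.floor(scale * image_height / patch_height)  # 36
cols = math.floor(scale * image_width / patch_width)    # 55
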
...@@ -312,7 +321,11 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -312,7 +321,11 @@ class Pix2StructImageProcessor(BaseImageProcessor):
return result return result
def normalize( def normalize(
self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs self,
image: np.ndarray,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
Normalize an image. image = (image - image_mean) / image_std. Normalize an image. image = (image - image_mean) / image_std.
...@@ -323,6 +336,11 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -323,6 +336,11 @@ class Pix2StructImageProcessor(BaseImageProcessor):
Args: Args:
image (`np.ndarray`): image (`np.ndarray`):
Image to normalize. Image to normalize.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
if image.dtype == np.uint8: if image.dtype == np.uint8:
image = image.astype(np.float32) image = image.astype(np.float32)
...@@ -332,7 +350,14 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -332,7 +350,14 @@ class Pix2StructImageProcessor(BaseImageProcessor):
std = np.std(image) std = np.std(image)
adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape))) adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape)))
return normalize(image, mean=mean, std=adjusted_stddev, **kwargs) return normalize(
image,
mean=mean,
std=adjusted_stddev,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
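
A numeric sketch of the stddev floor above, which keeps a constant image from dividing by zero (this mirrors per-image standardization as used here; the 2x2 image is illustrative):

import math
import numpy as np

image = np.full((2, 2), 5.0, dtype=np.float32)  # constant image: std == 0
mean = np.mean(image)                           # 5.0
std = np.std(image)                             # 0.0
adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape)))  # max(0.0, 0.5) = 0.5

normalized = (image - mean) / adjusted_stddev   # all zeros, no NaNs from 0/0
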
def preprocess( def preprocess(
self, self,
...@@ -344,6 +369,7 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -344,6 +369,7 @@ class Pix2StructImageProcessor(BaseImageProcessor):
patch_size: Optional[Dict[str, int]] = None, patch_size: Optional[Dict[str, int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> ImageInput: ) -> ImageInput:
""" """
...@@ -374,6 +400,17 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -374,6 +400,17 @@ class Pix2StructImageProcessor(BaseImageProcessor):
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_normalize = do_normalize if do_normalize is not None else self.do_normalize do_normalize = do_normalize if do_normalize is not None else self.do_normalize
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
...@@ -399,6 +436,10 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -399,6 +436,10 @@ class Pix2StructImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if is_vqa: if is_vqa:
if header_text is None: if header_text is None:
raise ValueError("A header text must be provided for VQA models.") raise ValueError("A header text must be provided for VQA models.")
...@@ -414,11 +455,13 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -414,11 +455,13 @@ class Pix2StructImageProcessor(BaseImageProcessor):
] ]
if do_normalize: if do_normalize:
images = [self.normalize(image=image) for image in images] images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]
# convert to torch tensor and permute # convert to torch tensor and permute
images = [ images = [
self.extract_flattened_patches(image=image, max_patches=max_patches, patch_size=patch_size) self.extract_flattened_patches(
image=image, max_patches=max_patches, patch_size=patch_size, input_data_format=input_data_format
)
for image in images for image in images
] ]
......
...@@ -30,6 +30,7 @@ from ...image_utils import ( ...@@ -30,6 +30,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -137,6 +138,7 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -137,6 +138,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
crop_pct: Optional[float] = None, crop_pct: Optional[float] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC, resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -166,6 +168,8 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -166,6 +168,8 @@ class PoolFormerImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image. Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size and ("height" not in size or "width" not in size): if "shortest_edge" not in size and ("height" not in size or "width" not in size):
...@@ -181,16 +185,27 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -181,16 +185,27 @@ class PoolFormerImageProcessor(BaseImageProcessor):
else: else:
raise ValueError("Invalid size for resize: {}".format(size)) raise ValueError("Invalid size for resize: {}".format(size))
output_size = get_resize_output_image_size(image, size=scale_size, default_to_square=False) output_size = get_resize_output_image_size(
image, size=scale_size, default_to_square=False, input_data_format=input_data_format
)
else: else:
if "shortest_edge" in size: if "shortest_edge" in size:
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False) output_size = get_resize_output_image_size(
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
elif "height" in size and "width" in size: elif "height" in size and "width" in size:
output_size = (size["height"], size["width"]) output_size = (size["height"], size["width"])
else: else:
raise ValueError("Invalid size for resize: {}".format(size)) raise ValueError("Invalid size for resize: {}".format(size))
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
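
Illustrative numbers for the `crop_pct` branch, assuming `scale_size` (computed in the elided lines) inflates the target size by `1 / crop_pct`, so that the later center crop back to `size` keeps roughly a `crop_pct` fraction of the resized short side:

size, crop_pct = 224, 0.9             # example values only, not taken from this hunk

scale_size = int(size / crop_pct)     # 248: shortest edge after the resize step
kept_fraction = size / scale_size     # ~0.903 of the resized short side survives the crop
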
def preprocess( def preprocess(
self, self,
...@@ -208,6 +223,7 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -208,6 +223,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
""" """
...@@ -250,6 +266,12 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -250,6 +266,12 @@ class PoolFormerImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of: The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
crop_pct = crop_pct if crop_pct is not None else self.crop_pct crop_pct = crop_pct if crop_pct is not None else self.crop_pct
...@@ -289,19 +311,38 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -289,19 +311,38 @@ class PoolFormerImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize: if do_resize:
images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images] images = [
self.resize(
image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
)
for image in images
]
if do_center_crop: if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images] images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize: if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
images = [to_channel_dimension_format(image, data_format) for image in images] images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images} data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
...@@ -26,6 +26,7 @@ from ...image_utils import ( ...@@ -26,6 +26,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -100,6 +101,7 @@ class PvtImageProcessor(BaseImageProcessor): ...@@ -100,6 +101,7 @@ class PvtImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR, resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -117,6 +119,13 @@ class PvtImageProcessor(BaseImageProcessor): ...@@ -117,6 +119,13 @@ class PvtImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns: Returns:
`np.ndarray`: The resized image. `np.ndarray`: The resized image.
...@@ -125,7 +134,14 @@ class PvtImageProcessor(BaseImageProcessor): ...@@ -125,7 +134,14 @@ class PvtImageProcessor(BaseImageProcessor):
if "height" not in size or "width" not in size: if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"]) output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess( def preprocess(
self, self,
...@@ -140,6 +156,7 @@ class PvtImageProcessor(BaseImageProcessor): ...@@ -140,6 +156,7 @@ class PvtImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -178,6 +195,12 @@ class PvtImageProcessor(BaseImageProcessor): ...@@ -178,6 +195,12 @@ class PvtImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image. - Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_rescale = do_rescale if do_rescale is not None else self.do_rescale
...@@ -207,16 +230,31 @@ class PvtImageProcessor(BaseImageProcessor): ...@@ -207,16 +230,31 @@ class PvtImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize: if do_resize:
images = [self.resize(image=image, size=size_dict, resample=resample) for image in images] images = [
self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize: if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
images = [to_channel_dimension_format(image, data_format) for image in images] images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images} data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
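
A small sketch of the final conversion step, using the same helpers these hunks import; passing `input_channel_dim` avoids re-inferring the layout a second time:

import numpy as np
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension, infer_channel_dimension_format

image = np.zeros((224, 224, 3), dtype=np.float32)   # channels-last input
fmt = infer_channel_dimension_format(image)         # ChannelDimension.LAST

out = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=fmt)
assert out.shape == (3, 224, 224)
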
...@@ -143,6 +143,7 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -143,6 +143,7 @@ class SamImageProcessor(BaseImageProcessor):
image: np.ndarray, image: np.ndarray,
pad_size: Dict[str, int], pad_size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -156,14 +157,22 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -156,14 +157,22 @@ class SamImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
`data_format` of the `image` will be used. `data_format` of the `image` will be used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
output_height, output_width = pad_size["height"], pad_size["width"] output_height, output_width = pad_size["height"], pad_size["width"]
input_height, input_width = get_image_size(image) input_height, input_width = get_image_size(image, channel_dim=input_data_format)
pad_width = output_width - input_width pad_width = output_width - input_width
pad_height = output_height - input_height pad_height = output_height - input_height
padded_image = pad(image, ((0, pad_height), (0, pad_width)), data_format=data_format, **kwargs) padded_image = pad(
image,
((0, pad_height), (0, pad_width)),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
return padded_image return padded_image
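
The padding arithmetic in isolation; `np.pad` stands in for the library `pad` helper (which also handles channel layouts), and the 1024x1024 `pad_size` is an assumed example rather than a value shown in this hunk:

import numpy as np

pad_size = {"height": 1024, "width": 1024}
image = np.zeros((683, 1024, 3), dtype=np.float32)  # channels-last example

pad_height = pad_size["height"] - image.shape[0]    # 341
pad_width = pad_size["width"] - image.shape[1]      # 0

# Pad only on the bottom and right, matching ((0, pad_height), (0, pad_width)).
padded = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)))
assert padded.shape == (1024, 1024, 3)
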
def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int): def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
...@@ -183,6 +192,7 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -183,6 +192,7 @@ class SamImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC, resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -202,15 +212,28 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -202,15 +212,28 @@ class SamImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns: Returns:
`np.ndarray`: The resized image. `np.ndarray`: The resized image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "longest_edge" not in size: if "longest_edge" not in size:
raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}") raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
input_size = get_image_size(image) input_size = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"]) output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
return resize(image, size=(output_height, output_width), resample=resample, data_format=data_format, **kwargs) return resize(
image,
size=(output_height, output_width),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
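
The body of `_get_preprocess_shape` is not shown in this hunk; the following is a plausible reconstruction of the longest-edge rule its call site implies, offered as a sketch rather than the repository's exact code:

def get_preprocess_shape(old_shape, longest_edge):
    """Scale (height, width) so the longer side equals longest_edge."""
    old_height, old_width = old_shape
    scale = longest_edge / max(old_height, old_width)
    return int(old_height * scale + 0.5), int(old_width * scale + 0.5)

get_preprocess_shape((1200, 800), 1024)  # (1024, 683)
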
def preprocess( def preprocess(
self, self,
...@@ -228,6 +251,7 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -228,6 +251,7 @@ class SamImageProcessor(BaseImageProcessor):
do_convert_rgb: bool = None, do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -272,6 +296,12 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -272,6 +296,12 @@ class SamImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image. - Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
...@@ -314,23 +344,40 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -314,23 +344,40 @@ class SamImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
original_sizes = [get_image_size(image) for image in images] original_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
if do_resize: if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images] images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
reshaped_input_sizes = [get_image_size(image) for image in images] reshaped_input_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize: if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
if do_pad: if do_pad:
images = [self.pad_image(image=image, pad_size=pad_size) for image in images] images = [
self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format) for image in images
]
images = [to_channel_dimension_format(image, data_format) for image in images] images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
encoded_outputs = BatchFeature( encoded_outputs = BatchFeature(
data={ data={
"pixel_values": images, "pixel_values": images,
...@@ -517,6 +564,7 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -517,6 +564,7 @@ class SamImageProcessor(BaseImageProcessor):
points_per_crop: Optional[int] = 32, points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1, crop_n_points_downscale_factor: Optional[List[int]] = 1,
device: Optional["torch.device"] = None, device: Optional["torch.device"] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
return_tensors: str = "pt", return_tensors: str = "pt",
): ):
""" """
...@@ -539,6 +587,8 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -539,6 +587,8 @@ class SamImageProcessor(BaseImageProcessor):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*, defaults to None): device (`torch.device`, *optional*, defaults to None):
Device to use for the computation. If None, cpu will be used. Device to use for the computation. If None, cpu will be used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
return_tensors (`str`, *optional*, defaults to `pt`): return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
""" """
...@@ -549,6 +599,7 @@ class SamImageProcessor(BaseImageProcessor): ...@@ -549,6 +599,7 @@ class SamImageProcessor(BaseImageProcessor):
overlap_ratio, overlap_ratio,
points_per_crop, points_per_crop,
crop_n_points_downscale_factor, crop_n_points_downscale_factor,
input_data_format,
) )
if return_tensors == "pt": if return_tensors == "pt":
if device is None: if device is None:
...@@ -855,6 +906,7 @@ def _generate_crop_boxes( ...@@ -855,6 +906,7 @@ def _generate_crop_boxes(
overlap_ratio: float = 512 / 1500, overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32, points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1, crop_n_points_downscale_factor: Optional[List[int]] = 1,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[List[List[int]], List[int]]: ) -> Tuple[List[List[int]], List[int]]:
""" """
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
...@@ -874,12 +926,14 @@ def _generate_crop_boxes( ...@@ -874,12 +926,14 @@ def _generate_crop_boxes(
Number of points to sample per crop. Number of points to sample per crop.
crop_n_points_downscale_factor (`int`, *optional*): crop_n_points_downscale_factor (`int`, *optional*):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
if isinstance(image, list): if isinstance(image, list):
raise ValueError("Only one image is allowed for crop generation.") raise ValueError("Only one image is allowed for crop generation.")
image = to_numpy_array(image) image = to_numpy_array(image)
original_size = get_image_size(image) original_size = get_image_size(image, input_data_format)
points_grid = [] points_grid = []
for i in range(crop_n_layers + 1): for i in range(crop_n_layers + 1):
...@@ -889,7 +943,7 @@ def _generate_crop_boxes( ...@@ -889,7 +943,7 @@ def _generate_crop_boxes(
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
cropped_images, point_grid_per_crop = _generate_crop_images( cropped_images, point_grid_per_crop = _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
) )
crop_boxes = np.array(crop_boxes) crop_boxes = np.array(crop_boxes)
crop_boxes = crop_boxes.astype(np.float32) crop_boxes = crop_boxes.astype(np.float32)
...@@ -935,7 +989,9 @@ def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): ...@@ -935,7 +989,9 @@ def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
return crop_boxes, layer_idxs return crop_boxes, layer_idxs
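
The per-layer crop counts described by the docstring above ((2**i)**2 boxes for the ith layer), as a quick check with an example layer count:

crop_n_layers = 2
boxes_per_layer = [(2**i) ** 2 for i in range(crop_n_layers + 1)]
# [1, 4, 16]: the full image, then 4 overlapping crops, then 16.
total_boxes = sum(boxes_per_layer)  # 21
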
def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_size, original_size): def _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
):
""" """
Takes as input the bounding boxes that are used to crop the image. Based on the crops, the corresponding points are Takes as input the bounding boxes that are used to crop the image. Based on the crops, the corresponding points are
also passed. also passed.
...@@ -945,7 +1001,7 @@ def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_siz ...@@ -945,7 +1001,7 @@ def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_siz
for i, crop_box in enumerate(crop_boxes): for i, crop_box in enumerate(crop_boxes):
left, top, right, bottom = crop_box left, top, right, bottom = crop_box
channel_dim = infer_channel_dimension_format(image) channel_dim = infer_channel_dimension_format(image, input_data_format)
if channel_dim == ChannelDimension.LAST: if channel_dim == ChannelDimension.LAST:
cropped_im = image[top:bottom, left:right, :] cropped_im = image[top:bottom, left:right, :]
else: else:
...@@ -953,7 +1009,7 @@ def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_siz ...@@ -953,7 +1009,7 @@ def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_siz
cropped_images.append(cropped_im) cropped_images.append(cropped_im)
cropped_im_size = get_image_size(cropped_im) cropped_im_size = get_image_size(cropped_im, channel_dim)
points_scale = np.array(cropped_im_size)[None, ::-1] points_scale = np.array(cropped_im_size)[None, ::-1]
points = points_grid[layer_idxs[i]] * points_scale points = points_grid[layer_idxs[i]] * points_scale
......
...@@ -27,6 +27,7 @@ from ...image_utils import ( ...@@ -27,6 +27,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -135,6 +136,7 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -135,6 +136,7 @@ class SegformerImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR, resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -152,6 +154,13 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -152,6 +154,13 @@ class SegformerImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns: Returns:
`np.ndarray`: The resized image. `np.ndarray`: The resized image.
...@@ -160,7 +169,14 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -160,7 +169,14 @@ class SegformerImageProcessor(BaseImageProcessor):
if "height" not in size or "width" not in size: if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"]) output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
def reduce_label(self, label: ImageInput) -> np.ndarray: def reduce_label(self, label: ImageInput) -> np.ndarray:
...@@ -183,18 +199,19 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -183,18 +199,19 @@ class SegformerImageProcessor(BaseImageProcessor):
rescale_factor: Optional[float] = None, rescale_factor: Optional[float] = None,
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
): ):
if do_reduce_labels: if do_reduce_labels:
image = self.reduce_label(image) image = self.reduce_label(image)
if do_resize: if do_resize:
image = self.resize(image=image, size=size, resample=resample) image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_rescale: if do_rescale:
image = self.rescale(image=image, scale=rescale_factor) image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize: if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std) image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
return image return image
...@@ -210,10 +227,13 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -210,10 +227,13 @@ class SegformerImageProcessor(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
image = to_numpy_array(image) image = to_numpy_array(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess( image = self._preprocess(
image=image, image=image,
do_reduce_labels=False, do_reduce_labels=False,
...@@ -225,9 +245,10 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -225,9 +245,10 @@ class SegformerImageProcessor(BaseImageProcessor):
do_normalize=do_normalize, do_normalize=do_normalize,
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
input_data_format=input_data_format,
) )
if data_format is not None: if data_format is not None:
image = to_channel_dimension_format(image, data_format) image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image return image
def _preprocess_mask( def _preprocess_mask(
...@@ -236,14 +257,19 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -236,14 +257,19 @@ class SegformerImageProcessor(BaseImageProcessor):
do_reduce_labels: bool = None, do_reduce_labels: bool = None,
do_resize: bool = None, do_resize: bool = None,
size: Dict[str, int] = None, size: Dict[str, int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single mask.""" """Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map) segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations # Add channel dimension if missing - needed for certain transformations
added_channel_dim = False
if segmentation_map.ndim == 2: if segmentation_map.ndim == 2:
added_channel_dim = True added_channel_dim = True
segmentation_map = segmentation_map[None, ...] segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
# reduce zero label if needed # reduce zero label if needed
segmentation_map = self._preprocess( segmentation_map = self._preprocess(
image=segmentation_map, image=segmentation_map,
...@@ -253,6 +279,7 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -253,6 +279,7 @@ class SegformerImageProcessor(BaseImageProcessor):
size=size, size=size,
do_rescale=False, do_rescale=False,
do_normalize=False, do_normalize=False,
input_data_format=input_data_format,
) )
# Remove extra channel dimension if added for processing # Remove extra channel dimension if added for processing
if added_channel_dim: if added_channel_dim:
...@@ -284,6 +311,7 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -284,6 +311,7 @@ class SegformerImageProcessor(BaseImageProcessor):
do_reduce_labels: Optional[bool] = None, do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
""" """
...@@ -326,6 +354,12 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -326,6 +354,12 @@ class SegformerImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of: The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_rescale = do_rescale if do_rescale is not None else self.do_rescale
...@@ -374,6 +408,7 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -374,6 +408,7 @@ class SegformerImageProcessor(BaseImageProcessor):
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
data_format=data_format, data_format=data_format,
input_data_format=input_data_format,
) )
for img in images for img in images
] ]
...@@ -387,6 +422,7 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -387,6 +422,7 @@ class SegformerImageProcessor(BaseImageProcessor):
do_reduce_labels=do_reduce_labels, do_reduce_labels=do_reduce_labels,
do_resize=do_resize, do_resize=do_resize,
size=size, size=size,
input_data_format=input_data_format,
) )
for segmentation_map in segmentation_maps for segmentation_map in segmentation_maps
] ]
......
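A minimal numpy sketch of the mask path above, with illustrative values: a 2D segmentation map temporarily gains a leading channel axis (which also pins `input_data_format` to channels-first), runs through the channel-aware transforms, and loses the axis again.

```python
import numpy as np

segmentation_map = np.random.randint(0, 150, size=(512, 683))

added_channel_dim = segmentation_map.ndim == 2
if added_channel_dim:
    # (1, H, W): channel axis added, so the input format is known to be FIRST
    segmentation_map = segmentation_map[None, ...]

# ... channel-aware transforms (resize, etc.) would run here ...

if added_channel_dim:
    segmentation_map = segmentation_map.squeeze(0)  # back to (H, W)
print(segmentation_map.shape)  # (512, 683)
```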
...@@ -20,7 +20,14 @@ import numpy as np ...@@ -20,7 +20,14 @@ import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import get_image_size, pad, to_channel_dimension_format from ...image_transforms import get_image_size, pad, to_channel_dimension_format
from ...image_utils import ChannelDimension, ImageInput, make_list_of_images, to_numpy_array, valid_images from ...image_utils import (
ChannelDimension,
ImageInput,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
)
from ...utils import TensorType, logging from ...utils import TensorType, logging
...@@ -57,7 +64,13 @@ class Swin2SRImageProcessor(BaseImageProcessor): ...@@ -57,7 +64,13 @@ class Swin2SRImageProcessor(BaseImageProcessor):
self.do_pad = do_pad self.do_pad = do_pad
self.pad_size = pad_size self.pad_size = pad_size
def pad(self, image: np.ndarray, size: int, data_format: Optional[Union[str, ChannelDimension]] = None): def pad(
self,
image: np.ndarray,
size: int,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
""" """
Pad an image to make the height and width divisible by `size`. Pad an image to make the height and width divisible by `size`.
...@@ -71,15 +84,26 @@ class Swin2SRImageProcessor(BaseImageProcessor): ...@@ -71,15 +84,26 @@ class Swin2SRImageProcessor(BaseImageProcessor):
image is used. Can be one of: image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns: Returns:
`np.ndarray`: The padded image. `np.ndarray`: The padded image.
""" """
old_height, old_width = get_image_size(image) old_height, old_width = get_image_size(image, input_data_format)
pad_height = (old_height // size + 1) * size - old_height pad_height = (old_height // size + 1) * size - old_height
pad_width = (old_width // size + 1) * size - old_width pad_width = (old_width // size + 1) * size - old_width
return pad(image, ((0, pad_height), (0, pad_width)), mode="symmetric", data_format=data_format) return pad(
image,
((0, pad_height), (0, pad_width)),
mode="symmetric",
data_format=data_format,
input_data_format=input_data_format,
)
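Worked example of the pad arithmetic above. Note that the formula pads by a full `size` when a dimension is already divisible; that matches the code as written.

```python
import numpy as np

def pad_amounts(old_height: int, old_width: int, size: int) -> tuple:
    pad_height = (old_height // size + 1) * size - old_height
    pad_width = (old_width // size + 1) * size - old_width
    return pad_height, pad_width

print(pad_amounts(510, 389, 8))  # (2, 3)  -> padded to 512 x 392
print(pad_amounts(512, 392, 8))  # (8, 8)  -> already divisible, pads a full block

image = np.zeros((510, 389))
ph, pw = pad_amounts(*image.shape, 8)
print(np.pad(image, ((0, ph), (0, pw)), mode="symmetric").shape)  # (512, 392)
```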
def preprocess( def preprocess(
self, self,
...@@ -90,6 +114,7 @@ class Swin2SRImageProcessor(BaseImageProcessor): ...@@ -90,6 +114,7 @@ class Swin2SRImageProcessor(BaseImageProcessor):
pad_size: Optional[int] = None, pad_size: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -104,12 +129,13 @@ class Swin2SRImageProcessor(BaseImageProcessor): ...@@ -104,12 +129,13 @@ class Swin2SRImageProcessor(BaseImageProcessor):
Rescale factor to rescale the image by if `do_rescale` is set to `True`. Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_pad (`bool`, *optional*, defaults to `True`): do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image to make the height and width divisible by `window_size`. Whether to pad the image to make the height and width divisible by `window_size`.
pad_size (`int`, *optional*, defaults to `32`): pad_size (`int`, *optional*, defaults to 32):
The size of the sliding window for the local attention. The size of the sliding window for the local attention.
return_tensors (`str` or `TensorType`, *optional*): return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of: The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`. - Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
...@@ -118,6 +144,12 @@ class Swin2SRImageProcessor(BaseImageProcessor): ...@@ -118,6 +144,12 @@ class Swin2SRImageProcessor(BaseImageProcessor):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image. - Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
...@@ -138,13 +170,22 @@ class Swin2SRImageProcessor(BaseImageProcessor): ...@@ -138,13 +170,22 @@ class Swin2SRImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_pad: if do_pad:
images = [self.pad(image, size=pad_size) for image in images] images = [self.pad(image, size=pad_size, input_data_format=input_data_format) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images] images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images} data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
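A hedged usage sketch of the new argument, assuming a transformers build that includes this PR; the image size and `pad_size=8` are arbitrary choices for the example.

```python
import numpy as np
from transformers import Swin2SRImageProcessor

processor = Swin2SRImageProcessor(do_rescale=True, do_pad=True, pad_size=8)
image = np.random.randint(0, 256, size=(510, 389, 3), dtype=np.uint8)

# Pin the input layout instead of relying on inference:
inputs = processor(image, return_tensors="np", input_data_format="channels_last")
print(inputs["pixel_values"].shape)  # (1, 3, 512, 392): padded, channels first
```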
...@@ -29,6 +29,7 @@ from ...image_utils import ( ...@@ -29,6 +29,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
is_valid_image, is_valid_image,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -155,6 +156,7 @@ class TvltImageProcessor(BaseImageProcessor): ...@@ -155,6 +156,7 @@ class TvltImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR, resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -171,15 +173,26 @@ class TvltImageProcessor(BaseImageProcessor): ...@@ -171,15 +173,26 @@ class TvltImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image. Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" in size: if "shortest_edge" in size:
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False) output_size = get_resize_output_image_size(
image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
elif "height" in size and "width" in size: elif "height" in size and "width" in size:
output_size = (size["height"], size["width"]) output_size = (size["height"], size["width"])
else: else:
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}") raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
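The `shortest_edge` branch scales the short side to the target value and preserves the aspect ratio. A simplified stand-in for `get_resize_output_image_size` (not the library function) to make the arithmetic concrete:

```python
def shortest_edge_output_size(height: int, width: int, shortest_edge: int) -> tuple:
    short, long = min(height, width), max(height, width)
    new_long = int(shortest_edge * long / short)  # preserve aspect ratio
    return (shortest_edge, new_long) if height <= width else (new_long, shortest_edge)

print(shortest_edge_output_size(480, 640, 224))  # (224, 298)
print(shortest_edge_output_size(640, 480, 224))  # (298, 224)
```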
def _preprocess_image( def _preprocess_image(
self, self,
...@@ -195,6 +208,7 @@ class TvltImageProcessor(BaseImageProcessor): ...@@ -195,6 +208,7 @@ class TvltImageProcessor(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
if do_resize and (size is None or resample is None): if do_resize and (size is None or resample is None):
...@@ -212,18 +226,21 @@ class TvltImageProcessor(BaseImageProcessor): ...@@ -212,18 +226,21 @@ class TvltImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
image = to_numpy_array(image) image = to_numpy_array(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
if do_resize: if do_resize:
image = self.resize(image=image, size=size, resample=resample) image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_center_crop: if do_center_crop:
image = self.center_crop(image, size=crop_size) image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
if do_rescale: if do_rescale:
image = self.rescale(image=image, scale=rescale_factor) image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize: if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std) image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
image = to_channel_dimension_format(image, data_format) image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image return image
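Order of operations in the single-image path, sketched with plain numpy stand-ins: every transform runs in the input layout, and the layout conversion happens exactly once at the end, which is why `input_data_format` is inferred a single time above.

```python
import numpy as np

image = np.random.randint(0, 256, size=(480, 640, 3)).astype(np.float32)

image = image / 255.0                    # rescale (factor 1/255), channels last
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
image = (image - mean) / std             # normalize along the last (channel) axis
image = np.transpose(image, (2, 0, 1))   # single layout conversion at the end
print(image.shape)                       # (3, 480, 640)
```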
def preprocess( def preprocess(
...@@ -244,6 +261,7 @@ class TvltImageProcessor(BaseImageProcessor): ...@@ -244,6 +261,7 @@ class TvltImageProcessor(BaseImageProcessor):
is_mixed: bool = False, is_mixed: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> BatchFeature: ) -> BatchFeature:
""" """
...@@ -291,6 +309,12 @@ class TvltImageProcessor(BaseImageProcessor): ...@@ -291,6 +309,12 @@ class TvltImageProcessor(BaseImageProcessor):
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the inferred channel dimension format of the input image. - Unset: Use the inferred channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns: Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields: [`BatchFeature`]: A [`BatchFeature`] with the following fields:
...@@ -361,6 +385,7 @@ class TvltImageProcessor(BaseImageProcessor): ...@@ -361,6 +385,7 @@ class TvltImageProcessor(BaseImageProcessor):
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
data_format=data_format, data_format=data_format,
input_data_format=input_data_format,
) )
for img in video for img in video
] ]
......
...@@ -30,6 +30,7 @@ from ...image_utils import ( ...@@ -30,6 +30,7 @@ from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
infer_channel_dimension_format,
is_valid_image, is_valid_image,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
...@@ -134,6 +135,7 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -134,6 +135,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
size: Dict[str, int], size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR, resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -150,15 +152,26 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -150,15 +152,26 @@ class VideoMAEImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image. Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" in size: if "shortest_edge" in size:
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False) output_size = get_resize_output_image_size(
image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
elif "height" in size and "width" in size: elif "height" in size and "width" in size:
output_size = (size["height"], size["width"]) output_size = (size["height"], size["width"])
else: else:
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}") raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def _preprocess_image( def _preprocess_image(
self, self,
...@@ -174,6 +187,7 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -174,6 +187,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
if do_resize and (size is None or resample is None): if do_resize and (size is None or resample is None):
...@@ -191,19 +205,22 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -191,19 +205,22 @@ class VideoMAEImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
image = to_numpy_array(image) image = to_numpy_array(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
if do_resize: if do_resize:
image = self.resize(image=image, size=size, resample=resample) image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_center_crop: if do_center_crop:
image = self.center_crop(image, size=crop_size) image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
if do_rescale: if do_rescale:
image = self.rescale(image=image, scale=rescale_factor) image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize: if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std) image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
image = to_channel_dimension_format(image, data_format) image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image return image
def preprocess( def preprocess(
...@@ -221,6 +238,7 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -221,6 +238,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
""" """
...@@ -262,6 +280,12 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -262,6 +280,12 @@ class VideoMAEImageProcessor(BaseImageProcessor):
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the inferred channel dimension format of the input image. - Unset: Use the inferred channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample resample = resample if resample is not None else self.resample
...@@ -300,6 +324,7 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -300,6 +324,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
data_format=data_format, data_format=data_format,
input_data_format=input_data_format,
) )
for img in video for img in video
] ]
......
...@@ -49,7 +49,9 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]: ...@@ -49,7 +49,9 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
return [max(values_i) for values_i in zip(*values)] return [max(values_i) for values_i in zip(*values)]
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
""" """
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
...@@ -59,31 +61,38 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr ...@@ -59,31 +61,38 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
output_size (`Tuple[int, int]`): output_size (`Tuple[int, int]`):
Output size of the mask. Output size of the mask.
""" """
input_height, input_width = get_image_size(image) input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64) mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1 mask[:input_height, :input_width] = 1
return mask return mask
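Worked example of the mask construction: a 2x3 valid region inside a 4x4 padded canvas (same layout as `make_pixel_mask`, written as a standalone demo).

```python
import numpy as np

def make_pixel_mask_demo(input_size: tuple, output_size: tuple) -> np.ndarray:
    # 1 marks valid pixels, 0 marks padding
    mask = np.zeros(output_size, dtype=np.int64)
    mask[: input_size[0], : input_size[1]] = 1
    return mask

print(make_pixel_mask_demo((2, 3), (4, 4)))
# [[1 1 1 0]
#  [1 1 1 0]
#  [0 0 0 0]
#  [0 0 0 0]]
```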
def get_max_height_width(images: List[np.ndarray]) -> List[int]: def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
""" """
Get the maximum height and width across all images in a batch. Get the maximum height and width across all images in a batch.
""" """
input_channel_dimension = infer_channel_dimension_format(images[0]) if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_channel_dimension == ChannelDimension.FIRST: if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images]) _, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST: elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images]) max_height, max_width, _ = max_across_indices([img.shape for img in images])
else: else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}") raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width) return (max_height, max_width)
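Example of the batch-level max with an explicit channels-first layout, mirroring the `max_across_indices` reduction above:

```python
import numpy as np

# Three channels-first images of mixed sizes.
images = [np.zeros((3, 480, 640)), np.zeros((3, 512, 384)), np.zeros((3, 300, 700))]
_, max_height, max_width = [max(dims) for dims in zip(*(img.shape for img in images))]
print((max_height, max_width))  # (512, 700)
```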
def get_resize_output_image_size( def get_resize_output_image_size(
input_image: np.ndarray, shorter: int = 800, longer: int = 1333, size_divisor: int = 32 input_image: np.ndarray,
shorter: int = 800,
longer: int = 1333,
size_divisor: int = 32,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
input_height, input_width = get_image_size(input_image) input_height, input_width = get_image_size(input_image, input_data_format)
min_size, max_size = shorter, longer min_size, max_size = shorter, longer
scale = min_size / min(input_height, input_width) scale = min_size / min(input_height, input_width)
...@@ -200,6 +209,7 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -200,6 +209,7 @@ class ViltImageProcessor(BaseImageProcessor):
size_divisor: int = 32, size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BICUBIC, resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -220,14 +230,25 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -220,14 +230,25 @@ class ViltImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image. Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size: if "shortest_edge" not in size:
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}") raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
shorter = size["shortest_edge"] shorter = size["shortest_edge"]
longer = int(1333 / 800 * shorter) longer = int(1333 / 800 * shorter)
output_size = get_resize_output_image_size(image, shorter=shorter, longer=longer, size_divisor=size_divisor) output_size = get_resize_output_image_size(
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
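The 1333/800 ratio caps the long side relative to the requested short side (COCO-style sizing); `get_resize_output_image_size` then scales to those bounds and snaps the result to `size_divisor`. A quick check of the bound computation shown above:

```python
def vilt_resize_bounds(shortest_edge: int) -> tuple:
    shorter = shortest_edge
    longer = int(1333 / 800 * shorter)  # long-side cap, same as the code above
    return shorter, longer

print(vilt_resize_bounds(384))  # (384, 639)
```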
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
def _pad_image( def _pad_image(
...@@ -236,18 +257,24 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -236,18 +257,24 @@ class ViltImageProcessor(BaseImageProcessor):
output_size: Tuple[int, int], output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0, constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Pad an image with zeros to the given size. Pad an image with zeros to the given size.
""" """
input_height, input_width = get_image_size(image) input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size output_height, output_width = output_size
pad_bottom = output_height - input_height pad_bottom = output_height - input_height
pad_right = output_width - input_width pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right)) padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad( padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
) )
return padded_image return padded_image
...@@ -259,6 +286,7 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -259,6 +286,7 @@ class ViltImageProcessor(BaseImageProcessor):
return_pixel_mask: bool = True, return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature: ) -> BatchFeature:
""" """
Pads a batch of images with zeros on the bottom and right, up to the size of the largest height and width Pads a batch of images with zeros on the bottom and right, up to the size of the largest height and width
...@@ -280,17 +308,28 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -280,17 +308,28 @@ class ViltImageProcessor(BaseImageProcessor):
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
""" """
pad_size = get_max_height_width(images) pad_size = get_max_height_width(images, input_data_format=input_data_format)
padded_images = [ padded_images = [
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format) self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images for image in images
] ]
data = {"pixel_values": padded_images} data = {"pixel_values": padded_images}
if return_pixel_mask: if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images] masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks data["pixel_mask"] = masks
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
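A hedged usage sketch of the batched `pad` with the new argument, assuming a transformers build that includes this PR; sizes are arbitrary.

```python
import numpy as np
from transformers import ViltImageProcessor

processor = ViltImageProcessor()
images = [
    np.zeros((3, 480, 640), dtype=np.float32),
    np.zeros((3, 512, 384), dtype=np.float32),
]
encoded = processor.pad(
    images, return_pixel_mask=True, return_tensors="np", input_data_format="channels_first"
)
print(encoded["pixel_values"].shape)  # (2, 3, 512, 640): padded to the batch max
print(encoded["pixel_mask"].shape)    # (2, 512, 640): 1 = valid, 0 = padding
```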
...@@ -310,6 +349,7 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -310,6 +349,7 @@ class ViltImageProcessor(BaseImageProcessor):
do_pad: Optional[bool] = None, do_pad: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs, **kwargs,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
""" """
...@@ -353,6 +393,12 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -353,6 +393,12 @@ class ViltImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of: The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size_divisor = size_divisor if size_divisor is not None else self.size_divisor size_divisor = size_divisor if size_divisor is not None else self.size_divisor
...@@ -387,21 +433,42 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -387,21 +433,42 @@ class ViltImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize: if do_resize:
images = [ images = [
self.resize(image=image, size=size, size_divisor=size_divisor, resample=resample) for image in images self.resize(
image=image,
size=size,
size_divisor=size_divisor,
resample=resample,
input_data_format=input_data_format,
)
for image in images
] ]
if do_rescale: if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images] images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize: if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
images = [to_channel_dimension_format(image, data_format) for image in images] images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
if do_pad: if do_pad:
encoded_outputs = self.pad(images, return_pixel_mask=True, return_tensors=return_tensors) encoded_outputs = self.pad(
images, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=data_format
)
else: else:
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
......