Unverified Commit 1c9134f0 authored by Pablo Montalvo's avatar Pablo Montalvo Committed by GitHub
Browse files

Abstract image processor arg checks. (#28843)



* abstract image processor arg checks.

* fix signatures and quality

* add validate_ method to rescale-prone processors

* add more validations

* quality

* quality

* fix formatting
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix formatting
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix formatting
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix formatting mishap
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix crop_size compatibility

* fix default mutable arg

* fix segmentation map + image arg validity

* remove segmentation check from arg validation

* fix quality

* fix missing segmap

* protect PILImageResampling type

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add back segmentation maps check

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent f7ef7cec
......@@ -29,6 +29,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
......@@ -243,8 +244,13 @@ class ImageGPTImageProcessor(BaseImageProcessor):
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
# Here, normalize() is using a constant factor to divide pixel values.
# hence, the method does not need iamge_mean and image_std.
validate_preprocess_arguments(
do_resize=do_resize,
size=size,
resample=resample,
)
if do_color_quantize and clusters is None:
raise ValueError("Clusters must be specified if do_color_quantize is True.")
......
......@@ -28,6 +28,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_pytesseract_available, is_vision_available, logging, requires_backends
......@@ -248,9 +249,11 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
validate_preprocess_arguments(
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -31,6 +31,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_pytesseract_available, is_vision_available, logging, requires_backends
......@@ -295,7 +296,6 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr
ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config
images = make_list_of_images(images)
if not valid_images(images):
......@@ -303,15 +303,16 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("If do_normalize is True, image_mean and image_std must be specified.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -35,6 +35,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
......@@ -251,7 +252,6 @@ class LevitImageProcessor(BaseImageProcessor):
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size, param_name="crop_size")
images = make_list_of_images(images)
if not valid_images(images):
......@@ -259,19 +259,18 @@ class LevitImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -39,6 +39,7 @@ from ...image_utils import (
is_scaled_image,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
IMAGENET_DEFAULT_MEAN,
......@@ -707,21 +708,23 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels
if do_resize is not None and size is None or size_divisor is None:
raise ValueError("If `do_resize` is True, `size` and `size_divisor` must be provided.")
if do_rescale is not None and rescale_factor is None:
raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.")
if do_normalize is not None and (image_mean is None or image_std is None):
raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.")
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
......
......@@ -39,6 +39,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
IMAGENET_DEFAULT_MEAN,
......@@ -724,20 +725,21 @@ class MaskFormerImageProcessor(BaseImageProcessor):
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
if do_resize is not None and size is None or size_divisor is None:
raise ValueError("If `do_resize` is True, `size` and `size_divisor` must be provided.")
if do_rescale is not None and rescale_factor is None:
raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.")
if do_normalize is not None and (image_mean is None or image_std is None):
raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.")
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
......
......@@ -35,6 +35,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
......@@ -249,18 +250,18 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -35,6 +35,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_torch_tensor, logging
......@@ -253,19 +254,18 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -29,6 +29,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging
......@@ -368,6 +369,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
......@@ -380,14 +383,15 @@ class MobileViTImageProcessor(BaseImageProcessor):
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
images = [
self._preprocess_image(
......
......@@ -38,6 +38,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
from ...utils.import_utils import is_cv2_available, is_vision_available
......@@ -446,18 +447,18 @@ class NougatImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_pad and size is None:
raise ValueError("Size must be specified if do_pad is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg.
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -42,6 +42,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
IMAGENET_DEFAULT_MEAN,
......@@ -708,20 +709,21 @@ class OneFormerImageProcessor(BaseImageProcessor):
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
if do_resize is not None and size is None:
raise ValueError("If `do_resize` is True, `size` must be provided.")
if do_rescale is not None and rescale_factor is None:
raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.")
if do_normalize is not None and (image_mean is None or image_std is None):
raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.")
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
......
......@@ -37,6 +37,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
TensorType,
......@@ -405,15 +406,18 @@ class Owlv2ImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# Here, pad and resize methods are different from the rest of image processors
# as they don't have any resampling in resize()
# or pad size in pad() (the maximum of (height, width) is taken instead).
# hence, these arguments don't need to be passed in validate_preprocess_arguments.
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
size=size,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -38,6 +38,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, logging
......@@ -348,18 +349,6 @@ class OwlViTImageProcessor(BaseImageProcessor):
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
if do_resize is not None and size is None:
raise ValueError("Size and max_size must be specified if do_resize is True.")
if do_center_crop is not None and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")
if do_rescale is not None and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize is not None and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
images = make_list_of_images(images)
if not valid_images(images):
......@@ -368,6 +357,19 @@ class OwlViTImageProcessor(BaseImageProcessor):
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays
images = [to_numpy_array(image) for image in images]
......
......@@ -32,6 +32,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
......@@ -290,18 +291,18 @@ class PerceiverImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_center_crop and crop_size is None:
raise ValueError("If `do_center_crop` is set to `True`, `crop_size` must be provided.")
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and image standard deviation must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -35,6 +35,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
......@@ -297,18 +298,18 @@ class PoolFormerImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_pct is None:
raise ValueError("Crop_pct must be specified if do_center_crop is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -31,6 +31,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
......@@ -222,12 +223,16 @@ class PvtImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -34,6 +34,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
TensorType,
......@@ -504,18 +505,18 @@ class SamImageProcessor(BaseImageProcessor):
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
if do_pad and pad_size is None:
raise ValueError("Pad size must be specified if do_pad is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=pad_size, # Here _preprocess needs do_pad and pad_size.
do_resize=do_resize,
size=size,
resample=resample,
)
images, original_sizes, reshaped_input_sizes = zip(
*(
......
......@@ -32,6 +32,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging
......@@ -387,22 +388,17 @@ class SegformerImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
images = [
self._preprocess_image(
image=img,
......
......@@ -32,6 +32,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
......@@ -178,13 +179,16 @@ class SiglipImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
......@@ -28,6 +28,7 @@ from ...image_utils import (
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
......@@ -165,9 +166,12 @@ class Swin2SRImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_pad=do_pad,
size_divisibility=pad_size, # Here the pad function simply requires pad_size.
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment