Unverified Commit 1c9134f0 authored by Pablo Montalvo, committed by GitHub
Browse files

Abstract image processor arg checks. (#28843)



* abstract image processor arg checks.

* fix signatures and quality

* add validate_ method to rescale-prone processors

* add more validations

* quality

* quality

* fix formatting
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix formatting
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix formatting
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix formatting mishap
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix crop_size compatibility

* fix default mutable arg

* fix segmentation map + image arg validity

* remove segmentation check from arg validation

* fix quality

* fix missing segmap

* protect PILImageResampling type

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add back segmentation maps check

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent f7ef7cec
...@@ -34,6 +34,7 @@ from ...image_utils import ( ...@@ -34,6 +34,7 @@ from ...image_utils import (
is_valid_image, is_valid_image,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_preprocess_arguments,
) )
from ...utils import TensorType, logging from ...utils import TensorType, logging
...@@ -212,17 +213,19 @@ class TvltImageProcessor(BaseImageProcessor): ...@@ -212,17 +213,19 @@ class TvltImageProcessor(BaseImageProcessor):
input_data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_size is None: validate_preprocess_arguments(
raise ValueError("Crop size must be specified if do_center_crop is True.") do_rescale=do_rescale,
rescale_factor=rescale_factor,
if do_rescale and rescale_factor is None: do_normalize=do_normalize,
raise ValueError("Rescale factor must be specified if do_rescale is True.") image_mean=image_mean,
image_std=image_std,
if do_normalize and (image_mean is None or image_std is None): do_center_crop=do_center_crop,
raise ValueError("Image mean and std must be specified if do_normalize is True.") crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
image = to_numpy_array(image) image = to_numpy_array(image)
......
...@@ -36,6 +36,7 @@ from ...image_utils import ( ...@@ -36,6 +36,7 @@ from ...image_utils import (
is_valid_image, is_valid_image,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_preprocess_arguments,
) )
from ...utils import TensorType, is_vision_available, logging from ...utils import TensorType, is_vision_available, logging
...@@ -285,20 +286,21 @@ class TvpImageProcessor(BaseImageProcessor): ...@@ -285,20 +286,21 @@ class TvpImageProcessor(BaseImageProcessor):
**kwargs, **kwargs,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_size is None: validate_preprocess_arguments(
raise ValueError("Crop size must be specified if do_center_crop is True.") do_rescale=do_rescale,
rescale_factor=rescale_factor,
if do_rescale and rescale_factor is None: do_normalize=do_normalize,
raise ValueError("Rescale factor must be specified if do_rescale is True.") image_mean=image_mean,
image_std=image_std,
if do_pad and pad_size is None: do_pad=do_pad,
raise ValueError("Padding size must be specified if do_pad is True.") size_divisibility=pad_size, # here the pad() method simply requires the pad_size argument.
do_center_crop=do_center_crop,
if do_normalize and (image_mean is None or image_std is None): crop_size=crop_size,
raise ValueError("Image mean and std must be specified if do_normalize is True.") do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
image = to_numpy_array(image) image = to_numpy_array(image)
......
...@@ -35,6 +35,7 @@ from ...image_utils import ( ...@@ -35,6 +35,7 @@ from ...image_utils import (
is_valid_image, is_valid_image,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_preprocess_arguments,
) )
from ...utils import TensorType, is_vision_available, logging from ...utils import TensorType, is_vision_available, logging
...@@ -191,17 +192,18 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -191,17 +192,18 @@ class VideoMAEImageProcessor(BaseImageProcessor):
input_data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
if do_resize and size is None or resample is None: validate_preprocess_arguments(
raise ValueError("Size and resample must be specified if do_resize is True.") do_rescale=do_rescale,
rescale_factor=rescale_factor,
if do_center_crop and crop_size is None: do_normalize=do_normalize,
raise ValueError("Crop size must be specified if do_center_crop is True.") image_mean=image_mean,
image_std=image_std,
if do_rescale and rescale_factor is None: do_center_crop=do_center_crop,
raise ValueError("Rescale factor must be specified if do_rescale is True.") crop_size=crop_size,
do_resize=do_resize,
if do_normalize and (image_mean is None or image_std is None): size=size,
raise ValueError("Image mean and std must be specified if do_normalize is True.") resample=resample,
)
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
image = to_numpy_array(image) image = to_numpy_array(image)
......
...@@ -32,6 +32,7 @@ from ...image_utils import ( ...@@ -32,6 +32,7 @@ from ...image_utils import (
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_preprocess_arguments,
) )
from ...utils import TensorType, is_vision_available, logging from ...utils import TensorType, is_vision_available, logging
...@@ -421,14 +422,18 @@ class ViltImageProcessor(BaseImageProcessor): ...@@ -421,14 +422,18 @@ class ViltImageProcessor(BaseImageProcessor):
"torch.Tensor, tf.Tensor or jax.ndarray." "torch.Tensor, tf.Tensor or jax.ndarray."
) )
if do_resize and size is None or resample is None: # Here the pad() method does not require any additional argument as it takes the maximum of (height, width).
raise ValueError("Size and resample must be specified if do_resize is True.") # Hence, it does not need to be passed to a validate_preprocess_arguments() method.
validate_preprocess_arguments(
if do_rescale and rescale_factor is None: do_rescale=do_rescale,
raise ValueError("Rescale factor must be specified if do_rescale is True.") rescale_factor=rescale_factor,
do_normalize=do_normalize,
if do_normalize and (image_mean is None or image_std is None): image_mean=image_mean,
raise ValueError("Image mean and std must be specified if do_normalize is True.") image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
......
...@@ -31,6 +31,7 @@ from ...image_utils import ( ...@@ -31,6 +31,7 @@ from ...image_utils import (
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_preprocess_arguments,
) )
from ...utils import TensorType, logging from ...utils import TensorType, logging
...@@ -221,12 +222,16 @@ class ViTImageProcessor(BaseImageProcessor): ...@@ -221,12 +222,16 @@ class ViTImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray." "torch.Tensor, tf.Tensor or jax.ndarray."
) )
validate_preprocess_arguments(
if do_resize and size is None: do_rescale=do_rescale,
raise ValueError("Size must be specified if do_resize is True.") rescale_factor=rescale_factor,
do_normalize=do_normalize,
if do_rescale and rescale_factor is None: image_mean=image_mean,
raise ValueError("Rescale factor must be specified if do_rescale is True.") image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
......
...@@ -36,6 +36,7 @@ from ...image_utils import ( ...@@ -36,6 +36,7 @@ from ...image_utils import (
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_preprocess_arguments,
) )
from ...utils import TensorType, is_vision_available, logging from ...utils import TensorType, is_vision_available, logging
...@@ -262,18 +263,18 @@ class ViTHybridImageProcessor(BaseImageProcessor): ...@@ -262,18 +263,18 @@ class ViTHybridImageProcessor(BaseImageProcessor):
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray." "torch.Tensor, tf.Tensor or jax.ndarray."
) )
validate_preprocess_arguments(
if do_resize and size is None: do_rescale=do_rescale,
raise ValueError("Size must be specified if do_resize is True.") rescale_factor=rescale_factor,
do_normalize=do_normalize,
if do_center_crop and crop_size is None: image_mean=image_mean,
raise ValueError("Crop size must be specified if do_center_crop is True.") image_std=image_std,
do_center_crop=do_center_crop,
if do_rescale and rescale_factor is None: crop_size=crop_size,
raise ValueError("Rescale factor must be specified if do_rescale is True.") do_resize=do_resize,
size=size,
if do_normalize and (image_mean is None or image_std is None): resample=resample,
raise ValueError("Image mean and std must be specified if do_normalize is True.") )
# PIL RGBA images are converted to RGB # PIL RGBA images are converted to RGB
if do_convert_rgb: if do_convert_rgb:
......
...@@ -31,6 +31,7 @@ from ...image_utils import ( ...@@ -31,6 +31,7 @@ from ...image_utils import (
make_list_of_images, make_list_of_images,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_preprocess_arguments,
) )
from ...utils import TensorType, logging from ...utils import TensorType, logging
...@@ -197,25 +198,28 @@ class VitMatteImageProcessor(BaseImageProcessor): ...@@ -197,25 +198,28 @@ class VitMatteImageProcessor(BaseImageProcessor):
images = make_list_of_images(images) images = make_list_of_images(images)
trimaps = make_list_of_images(trimaps, expected_ndims=2) trimaps = make_list_of_images(trimaps, expected_ndims=2)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if not valid_images(trimaps): if not valid_images(trimaps):
raise ValueError( raise ValueError(
"Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, " "Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray." "torch.Tensor, tf.Tensor or jax.ndarray."
) )
if do_rescale and rescale_factor is None: images = make_list_of_images(images)
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_pad and size_divisibility is None:
raise ValueError("Size divisilibyt must be specified if do_pad is True.")
if do_normalize and (image_mean is None or image_std is None): if not valid_images(images):
raise ValueError("Image mean and std must be specified if do_normalize is True.") raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisibility,
)
# All transformations expect numpy arrays. # All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
......
...@@ -38,6 +38,7 @@ from ...image_utils import ( ...@@ -38,6 +38,7 @@ from ...image_utils import (
is_valid_image, is_valid_image,
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_preprocess_arguments,
) )
from ...utils import logging from ...utils import logging
...@@ -240,17 +241,19 @@ class VivitImageProcessor(BaseImageProcessor): ...@@ -240,17 +241,19 @@ class VivitImageProcessor(BaseImageProcessor):
input_data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Preprocesses a single image.""" """Preprocesses a single image."""
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_center_crop and crop_size is None: validate_preprocess_arguments(
raise ValueError("Crop size must be specified if do_center_crop is True.") do_rescale=do_rescale,
rescale_factor=rescale_factor,
if do_rescale and rescale_factor is None: do_normalize=do_normalize,
raise ValueError("Rescale factor must be specified if do_rescale is True.") image_mean=image_mean,
image_std=image_std,
if do_normalize and (image_mean is None or image_std is None): do_center_crop=do_center_crop,
raise ValueError("Image mean and std must be specified if do_normalize is True.") crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
if offset and not do_rescale: if offset and not do_rescale:
raise ValueError("For offset, do_rescale must also be set to True.") raise ValueError("For offset, do_rescale must also be set to True.")
......
...@@ -47,6 +47,7 @@ from ...image_utils import ( ...@@ -47,6 +47,7 @@ from ...image_utils import (
to_numpy_array, to_numpy_array,
valid_images, valid_images,
validate_annotations, validate_annotations,
validate_preprocess_arguments,
) )
from ...utils import ( from ...utils import (
TensorType, TensorType,
...@@ -1185,16 +1186,25 @@ class YolosImageProcessor(BaseImageProcessor): ...@@ -1185,16 +1186,25 @@ class YolosImageProcessor(BaseImageProcessor):
do_pad = self.do_pad if do_pad is None else do_pad do_pad = self.do_pad if do_pad is None else do_pad
format = self.format if format is None else format format = self.format if format is None else format
if do_resize is not None and size is None: images = make_list_of_images(images)
raise ValueError("Size and max_size must be specified if do_resize is True.")
if do_rescale is not None and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize is not None and (image_mean is None or image_std is None): if not valid_images(images):
raise ValueError("Image mean and std must be specified if do_normalize is True.") raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
# Here the pad() method pads using the max of (width, height) and does not need to be validated.
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
images = make_list_of_images(images)
if annotations is not None and isinstance(annotations, dict): if annotations is not None and isinstance(annotations, dict):
annotations = [annotations] annotations = [annotations]
...@@ -1203,12 +1213,6 @@ class YolosImageProcessor(BaseImageProcessor): ...@@ -1203,12 +1213,6 @@ class YolosImageProcessor(BaseImageProcessor):
f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
) )
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
format = AnnotationFormat(format) format = AnnotationFormat(format)
if annotations is not None: if annotations is not None:
validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment