Unverified Commit b58beebe authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Add vision requirement to image transforms (#20712)

* Add require_vision decorator

* Fixup

* Use requires_backends

* Add requires_backend to utils functions
parent fd2bed7f
...@@ -18,26 +18,28 @@ from typing import Iterable, List, Optional, Tuple, Union ...@@ -18,26 +18,28 @@ from typing import Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
from transformers.utils import ExplicitEnum, TensorType from transformers.image_utils import (
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
if is_vision_available():
import PIL
from .image_utils import (
ChannelDimension, ChannelDimension,
PILImageResampling,
get_channel_dimension_axis, get_channel_dimension_axis,
get_image_size, get_image_size,
infer_channel_dimension_format, infer_channel_dimension_format,
is_jax_tensor,
is_tf_tensor,
is_torch_tensor,
to_numpy_array, to_numpy_array,
) )
from transformers.utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
from transformers.utils.import_utils import (
is_flax_available,
is_tf_available,
is_torch_available,
is_vision_available,
requires_backends,
)
if is_vision_available():
import PIL
from .image_utils import PILImageResampling
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -116,9 +118,9 @@ def rescale( ...@@ -116,9 +118,9 @@ def rescale(
def to_pil_image( def to_pil_image(
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"], image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None, do_rescale: Optional[bool] = None,
) -> PIL.Image.Image: ) -> "PIL.Image.Image":
""" """
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
needed. needed.
...@@ -133,6 +135,8 @@ def to_pil_image( ...@@ -133,6 +135,8 @@ def to_pil_image(
Returns: Returns:
`PIL.Image.Image`: The converted image. `PIL.Image.Image`: The converted image.
""" """
requires_backends(to_pil_image, ["vision"])
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
return image return image
...@@ -226,7 +230,7 @@ def get_resize_output_image_size( ...@@ -226,7 +230,7 @@ def get_resize_output_image_size(
def resize( def resize(
image, image,
size: Tuple[int, int], size: Tuple[int, int],
resample=PILImageResampling.BILINEAR, resample: "PILImageResampling" = None,
reducing_gap: Optional[int] = None, reducing_gap: Optional[int] = None,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
return_numpy: bool = True, return_numpy: bool = True,
...@@ -253,6 +257,10 @@ def resize( ...@@ -253,6 +257,10 @@ def resize(
Returns: Returns:
`np.ndarray`: The resized image. `np.ndarray`: The resized image.
""" """
requires_backends(resize, ["vision"])
resample = resample if resample is not None else PILImageResampling.BILINEAR
if not len(size) == 2: if not len(size) == 2:
raise ValueError("size must have 2 elements") raise ValueError("size must have 2 elements")
...@@ -303,6 +311,8 @@ def normalize( ...@@ -303,6 +311,8 @@ def normalize(
data_format (`ChannelDimension`, *optional*): data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input. The channel dimension format of the output image. If unset, will use the inferred format from the input.
""" """
requires_backends(normalize, ["vision"])
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
warnings.warn( warnings.warn(
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.", "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
...@@ -372,6 +382,8 @@ def center_crop( ...@@ -372,6 +382,8 @@ def center_crop(
Returns: Returns:
`np.ndarray`: The cropped image. `np.ndarray`: The cropped image.
""" """
requires_backends(center_crop, ["vision"])
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
warnings.warn( warnings.warn(
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.", "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
......
...@@ -28,6 +28,7 @@ from .utils import ( ...@@ -28,6 +28,7 @@ from .utils import (
is_torch_available, is_torch_available,
is_torch_tensor, is_torch_tensor,
is_vision_available, is_vision_available,
requires_backends,
to_numpy, to_numpy,
) )
from .utils.constants import ( # noqa: F401 from .utils.constants import ( # noqa: F401
...@@ -64,7 +65,8 @@ class ChannelDimension(ExplicitEnum): ...@@ -64,7 +65,8 @@ class ChannelDimension(ExplicitEnum):
def is_valid_image(img): def is_valid_image(img):
return ( return (
isinstance(img, (PIL.Image.Image, np.ndarray)) (is_vision_available() and isinstance(img, PIL.Image.Image))
or isinstance(img, np.ndarray)
or is_torch_tensor(img) or is_torch_tensor(img)
or is_tf_tensor(img) or is_tf_tensor(img)
or is_jax_tensor(img) or is_jax_tensor(img)
...@@ -90,7 +92,10 @@ def is_batched(img): ...@@ -90,7 +92,10 @@ def is_batched(img):
def to_numpy_array(img) -> np.ndarray: def to_numpy_array(img) -> np.ndarray:
if isinstance(img, PIL.Image.Image): if not is_valid_image(img):
raise ValueError(f"Invalid image type: {type(img)}")
if is_vision_available() and isinstance(img, PIL.Image.Image):
return np.array(img) return np.array(img)
return to_numpy(img) return to_numpy(img)
...@@ -215,6 +220,7 @@ def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image": ...@@ -215,6 +220,7 @@ def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
Returns: Returns:
`PIL.Image.Image`: A PIL Image. `PIL.Image.Image`: A PIL Image.
""" """
requires_backends(load_image, ["vision"])
if isinstance(image, str): if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"): if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file # We need to actually check for a real protocol, otherwise it's impossible to use a local file
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment