Unverified Commit 2030d208 authored by Philip Meier, committed by GitHub

register tensor and PIL kernel the same way as datapoints (#7797)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 84db2ac4
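
For orientation before the diff: this change replaces the per-dispatcher isinstance chains (plain tensor / datapoint / PIL image) with a single kernel registry in which torch.Tensor and PIL.Image.Image are registered the same way as the datapoint classes. The sketch below is a minimal, self-contained approximation of that pattern, not torchvision's actual implementation; the helper names mirror _KERNEL_REGISTRY, _register_kernel_internal, and _get_kernel from the diff, while the simplified bodies and the toy Tensor/Image classes are illustrative assumptions.

# Minimal sketch of the unified dispatch pattern (simplified; see caveats above).
from collections import defaultdict

_KERNEL_REGISTRY = defaultdict(dict)  # {dispatcher: {input_type: kernel}}


def _register_kernel_internal(dispatcher, input_type):
    # Decorator: register `kernel` as the implementation of `dispatcher`
    # for inputs of `input_type`.
    def decorator(kernel):
        _KERNEL_REGISTRY[dispatcher][input_type] = kernel
        return kernel

    return decorator


def _get_kernel(dispatcher, input_type):
    # Walk the MRO so a subclass of a registered type (e.g. a custom
    # Image subclass) falls back to the kernel of its parent class.
    registry = _KERNEL_REGISTRY[dispatcher]
    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
    raise TypeError(f"{dispatcher.__name__} has no kernel registered for {input_type}")


class Tensor:  # toy stand-in for torch.Tensor
    pass


class Image(Tensor):  # toy stand-in for datapoints.Image
    pass


def horizontal_flip(inpt):
    # Dispatcher bodies shrink to a registry lookup plus a call; the old
    # TypeError branch for unknown inputs now lives in _get_kernel.
    kernel = _get_kernel(horizontal_flip, type(inpt))
    return kernel(inpt)


@_register_kernel_internal(horizontal_flip, Tensor)
@_register_kernel_internal(horizontal_flip, Image)
def horizontal_flip_image_tensor(image):
    return image  # the real kernel would flip the last spatial dimension


assert _get_kernel(horizontal_flip, Image) is horizontal_flip_image_tensor

One visible consequence throughout the diff: PIL kernels that used to be plain aliases (e.g. posterize_image_pil = _FP.posterize) are now assigned through the registration decorator, and resize/resized_crop gain small PIL dispatch wrappers that keep the antialias warning behavior.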
@@ -2,7 +2,6 @@ import inspect
import math
import os
import re
from unittest import mock
import numpy as np
import PIL.Image
@@ -25,7 +24,6 @@ from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
from torchvision.transforms.v2.utils import is_simple_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -359,18 +357,6 @@ class TestDispatchers:
def test_scriptable(self, dispatcher):
script(dispatcher)
@image_sample_inputs
def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = torch.Tensor(image_datapoint)
kernel_info = info.kernel_infos[datapoints.Image]
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
info.dispatcher(image_simple_tensor, *other_args, **kwargs)
spy.assert_called_once()
@image_sample_inputs
def test_simple_tensor_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
@@ -381,25 +367,6 @@ class TestDispatchers:
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
assert type(output) is torch.Tensor
@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
)
def test_dispatch_pil(self, info, args_kwargs, spy_on):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
if image_datapoint.ndim > 3:
pytest.skip("Input is batched")
image_pil = F.to_image_pil(image_datapoint)
pil_kernel_info = info.pil_kernel_info
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
info.dispatcher(image_pil, *other_args, **kwargs)
spy.assert_called_once()
@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
@@ -416,28 +383,6 @@ class TestDispatchers:
assert isinstance(output, PIL.Image.Image)
@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
(datapoint, *other_args), kwargs = args_kwargs.load()
input_type = type(datapoint)
wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if hasattr(wrapped_kernel, "__wrapped__"):
assert wrapped_kernel.__wrapped__ is info.kernels[input_type]
spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
info.dispatcher(datapoint, *other_args, **kwargs)
spy.assert_called_once()
@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
@@ -449,6 +394,9 @@ class TestDispatchers:
assert isinstance(output, type(datapoint))
if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
assert output.format == datapoint.format
@pytest.mark.parametrize(
("dispatcher_info", "datapoint_type", "kernel_info"),
[
@@ -39,7 +39,7 @@ from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal
@pytest.fixture(autouse=True)
@@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
"""Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
preserved in doing so. For bounding boxes also checks that the format is preserved.
"""
input_type = type(input)
if isinstance(input, datapoints.Datapoint):
wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if hasattr(wrapped_kernel, "__wrapped__"):
assert wrapped_kernel.__wrapped__ is kernel
spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
output = dispatcher(input, *args, **kwargs)
spy.assert_called_once()
else:
with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy:
output = dispatcher(input, *args, **kwargs)
spy.assert_called_once()
assert isinstance(output, input_type)
if isinstance(input, datapoints.BoundingBoxes):
assert output.format == input.format
def check_dispatcher(
dispatcher,
# TODO: remove this parameter
kernel,
input,
*args,
check_scripted_smoke=True,
check_dispatch=True,
**kwargs,
):
unknown_input = object()
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
dispatcher(unknown_input, *args, **kwargs)
with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
dispatcher(unknown_input, *args, **kwargs)
output = dispatcher(input, *args, **kwargs)
spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")
assert isinstance(output, type(input))
if isinstance(input, datapoints.BoundingBoxes):
assert output.format == input.format
if check_scripted_smoke:
_check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
if check_dispatch:
_check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)
def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
"""Checks if the signature of the dispatcher matches the kernel signature."""
@@ -412,18 +385,20 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
@pytest.mark.parametrize(
("dispatcher", "registered_datapoint_clss"),
("dispatcher", "registered_input_types"),
[(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
)
def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
missing = {
torch.Tensor,
PIL.Image.Image,
datapoints.Image,
datapoints.BoundingBoxes,
datapoints.Mask,
datapoints.Video,
} - registered_datapoint_clss
} - registered_input_types
if missing:
names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
names = sorted(str(t) for t in missing)
raise AssertionError(
"\n".join(
[
@@ -1753,11 +1728,6 @@ class TestToDtype:
F.to_dtype,
kernel,
make_input(dtype=input_dtype, device=device),
# TODO: we could leave check_dispatch to True but it currently fails
# in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
# We should be able to put this back if we change the dispatch
# mechanism e.g. via https://github.com/pytorch/vision/pull/7733
check_dispatch=False,
dtype=output_dtype,
scale=scale,
)
@@ -2208,9 +2178,105 @@ class TestRegisterKernel:
t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
def test_bad_disaptcher_name(self):
class CustomDatapoint(datapoints.Datapoint):
def test_errors(self):
with pytest.raises(ValueError, match="Could not find dispatcher with name"):
F.register_kernel("bad_name", datapoints.Image)
with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
F.register_kernel(datapoints.Image, F.resize)
with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
F.register_kernel(F.resize, object)
with pytest.raises(ValueError, match="already has a kernel registered for type"):
F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
class TestGetKernel:
# We are using F.resize as the dispatcher and the kernels below as proxies. Any other dispatcher / kernel
# combination would also be fine.
KERNELS = {
torch.Tensor: F.resize_image_tensor,
PIL.Image.Image: F.resize_image_pil,
datapoints.Image: F.resize_image_tensor,
datapoints.BoundingBoxes: F.resize_bounding_boxes,
datapoints.Mask: F.resize_mask,
datapoints.Video: F.resize_video,
}
def test_unsupported_types(self):
class MyTensor(torch.Tensor):
pass
with pytest.raises(ValueError, match="Could not find dispatcher with name"):
F.register_kernel("bad_name", CustomDatapoint)
class MyPILImage(PIL.Image.Image):
pass
for input_type in [str, int, object, MyTensor, MyPILImage]:
with pytest.raises(
TypeError,
match=(
"supports inputs of type torch.Tensor, PIL.Image.Image, "
"and subclasses of torchvision.datapoints.Datapoint"
),
):
_get_kernel(F.resize, input_type)
def test_exact_match(self):
# We cannot use F.resize together with self.KERNELS mapping directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
# here, register the kernels without wrapper, and check the exact matching afterwards.
def resize_with_pure_kernels():
pass
for input_type, kernel in self.KERNELS.items():
_register_kernel_internal(resize_with_pure_kernels, input_type, datapoint_wrapper=False)(kernel)
assert _get_kernel(resize_with_pure_kernels, input_type) is kernel
def test_builtin_datapoint_subclass(self):
# We cannot use F.resize together with self.KERNELS mapping directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
# here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
# to the kernel of the corresponding superclass
def resize_with_pure_kernels():
pass
class MyImage(datapoints.Image):
pass
class MyBoundingBoxes(datapoints.BoundingBoxes):
pass
class MyMask(datapoints.Mask):
pass
class MyVideo(datapoints.Video):
pass
for custom_datapoint_subclass in [
MyImage,
MyBoundingBoxes,
MyMask,
MyVideo,
]:
builtin_datapoint_class = custom_datapoint_subclass.__mro__[1]
builtin_datapoint_kernel = self.KERNELS[builtin_datapoint_class]
_register_kernel_internal(resize_with_pure_kernels, builtin_datapoint_class, datapoint_wrapper=False)(
builtin_datapoint_kernel
)
assert _get_kernel(resize_with_pure_kernels, custom_datapoint_subclass) is builtin_datapoint_kernel
def test_datapoint_subclass(self):
class MyDatapoint(datapoints.Datapoint):
pass
# Note that this will be an error in the future
assert _get_kernel(F.resize, MyDatapoint) is _noop
def resize_my_datapoint():
pass
_register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)
assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
@@ -7,7 +7,7 @@ from torchvision import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
@@ -20,23 +20,16 @@ def erase(
v: torch.Tensor,
inplace: bool = False,
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(erase)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(erase, type(inpt))
return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(erase)
kernel = _get_kernel(erase, type(inpt))
return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, datapoints.Image)
def erase_image_tensor(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
@@ -48,7 +41,7 @@ def erase_image_tensor(
return image
@torch.jit.unused
@_register_kernel_internal(erase, PIL.Image.Image)
def erase_image_pil(
image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
@@ -10,29 +10,20 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
from ._misc import _num_value_bits, to_dtype_image_tensor
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
def rgb_to_grayscale(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(rgb_to_grayscale)
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(rgb_to_grayscale, type(inpt))
return kernel(inpt, num_output_channels=num_output_channels)
elif isinstance(inpt, PIL.Image.Image):
return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(rgb_to_grayscale)
kernel = _get_kernel(rgb_to_grayscale, type(inpt))
return kernel(inpt, num_output_channels=num_output_channels)
# `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a
@@ -56,12 +47,19 @@ def _rgb_to_grayscale_image_tensor(
return l_img
@_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
@_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)
rgb_to_grayscale_image_pil = _FP.to_grayscale
@_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image)
def rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
return _FP.to_grayscale(image, num_output_channels=num_output_channels)
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -74,23 +72,16 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(adjust_brightness, type(inpt))
return kernel(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(adjust_brightness)
kernel = _get_kernel(adjust_brightness, type(inpt))
return kernel(inpt, brightness_factor=brightness_factor)
@_register_kernel_internal(adjust_brightness, torch.Tensor)
@_register_kernel_internal(adjust_brightness, datapoints.Image)
def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0:
@@ -106,6 +97,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
return output if fp else output.to(image.dtype)
@_register_kernel_internal(adjust_brightness, PIL.Image.Image)
def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image:
return _FP.adjust_brightness(image, brightness_factor=brightness_factor)
@@ -117,23 +109,16 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_saturation)
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if torch.jit.is_scripting():
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(adjust_saturation, type(inpt))
return kernel(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(adjust_saturation)
kernel = _get_kernel(adjust_saturation, type(inpt))
return kernel(inpt, saturation_factor=saturation_factor)
@_register_kernel_internal(adjust_saturation, torch.Tensor)
@_register_kernel_internal(adjust_saturation, datapoints.Image)
def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0:
@@ -153,7 +138,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
return _blend(image, grayscale_image, saturation_factor)
adjust_saturation_image_pil = _FP.adjust_saturation
adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)
@_register_kernel_internal(adjust_saturation, datapoints.Video)
@@ -163,23 +148,16 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_contrast)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(adjust_contrast, type(inpt))
return kernel(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(adjust_contrast)
kernel = _get_kernel(adjust_contrast, type(inpt))
return kernel(inpt, contrast_factor=contrast_factor)
@_register_kernel_internal(adjust_contrast, torch.Tensor)
@_register_kernel_internal(adjust_contrast, datapoints.Image)
def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if contrast_factor < 0:
@@ -199,7 +177,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
return _blend(image, mean, contrast_factor)
adjust_contrast_image_pil = _FP.adjust_contrast
adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)
@_register_kernel_internal(adjust_contrast, datapoints.Video)
@@ -209,23 +187,16 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_sharpness)
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
if torch.jit.is_scripting():
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(adjust_sharpness, type(inpt))
return kernel(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(adjust_sharpness)
kernel = _get_kernel(adjust_sharpness, type(inpt))
return kernel(inpt, sharpness_factor=sharpness_factor)
@_register_kernel_internal(adjust_sharpness, torch.Tensor)
@_register_kernel_internal(adjust_sharpness, datapoints.Image)
def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = image.shape[-3:]
@@ -279,7 +250,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
return output
adjust_sharpness_image_pil = _FP.adjust_sharpness
adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)
@_register_kernel_internal(adjust_sharpness, datapoints.Video)
@@ -289,21 +260,13 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_hue)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(adjust_hue, type(inpt))
return kernel(inpt, hue_factor=hue_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(adjust_hue)
kernel = _get_kernel(adjust_hue, type(inpt))
return kernel(inpt, hue_factor=hue_factor)
def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
@@ -370,6 +333,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
@_register_kernel_internal(adjust_hue, torch.Tensor)
@_register_kernel_internal(adjust_hue, datapoints.Image)
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
if not (-0.5 <= hue_factor <= 0.5):
@@ -398,7 +362,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True)
adjust_hue_image_pil = _FP.adjust_hue
adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)
@_register_kernel_internal(adjust_hue, datapoints.Video)
@@ -408,23 +372,16 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_gamma)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(adjust_gamma, type(inpt))
return kernel(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(adjust_gamma)
kernel = _get_kernel(adjust_gamma, type(inpt))
return kernel(inpt, gamma=gamma, gain=gain)
@_register_kernel_internal(adjust_gamma, torch.Tensor)
@_register_kernel_internal(adjust_gamma, datapoints.Image)
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if gamma < 0:
@@ -445,7 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
return to_dtype_image_tensor(output, image.dtype, scale=True)
adjust_gamma_image_pil = _FP.adjust_gamma
adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)
@_register_kernel_internal(adjust_gamma, datapoints.Video)
@@ -455,23 +412,16 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(posterize)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(posterize, type(inpt))
return kernel(inpt, bits=bits)
elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(posterize)
kernel = _get_kernel(posterize, type(inpt))
return kernel(inpt, bits=bits)
@_register_kernel_internal(posterize, torch.Tensor)
@_register_kernel_internal(posterize, datapoints.Image)
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
if image.is_floating_point():
@@ -486,7 +436,7 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
return image & mask
posterize_image_pil = _FP.posterize
posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)
@_register_kernel_internal(posterize, datapoints.Video)
@@ -496,23 +446,16 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(solarize)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(solarize, type(inpt))
return kernel(inpt, threshold=threshold)
elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(solarize)
kernel = _get_kernel(solarize, type(inpt))
return kernel(inpt, threshold=threshold)
@_register_kernel_internal(solarize, torch.Tensor)
@_register_kernel_internal(solarize, datapoints.Image)
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
if threshold > _max_value(image.dtype):
@@ -521,7 +464,7 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor
return torch.where(image >= threshold, invert_image_tensor(image), image)
solarize_image_pil = _FP.solarize
solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)
@_register_kernel_internal(solarize, datapoints.Video)
@@ -531,25 +474,16 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(autocontrast)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(autocontrast, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(autocontrast)
kernel = _get_kernel(autocontrast, type(inpt))
return kernel(inpt)
@_register_kernel_internal(autocontrast, torch.Tensor)
@_register_kernel_internal(autocontrast, datapoints.Image)
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
c = image.shape[-3]
@@ -580,7 +514,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype)
autocontrast_image_pil = _FP.autocontrast
autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)
@_register_kernel_internal(autocontrast, datapoints.Video)
@@ -590,25 +524,16 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(equalize)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return equalize_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(equalize, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(equalize)
kernel = _get_kernel(equalize, type(inpt))
return kernel(inpt)
@_register_kernel_internal(equalize, torch.Tensor)
@_register_kernel_internal(equalize, datapoints.Image)
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
@@ -679,7 +604,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
return to_dtype_image_tensor(output, output_dtype, scale=True)
equalize_image_pil = _FP.equalize
equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)
@_register_kernel_internal(equalize, datapoints.Video)
@@ -689,25 +614,16 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(invert)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return invert_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(invert, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(invert)
kernel = _get_kernel(invert, type(inpt))
return kernel(inpt)
@_register_kernel_internal(invert, torch.Tensor)
@_register_kernel_internal(invert, datapoints.Image)
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point():
@@ -719,7 +635,7 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)
invert_image_pil = _FP.invert
invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)
@_register_kernel_internal(invert, datapoints.Video)
@@ -25,13 +25,7 @@ from torchvision.utils import _log_api_usage_once
from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
from ._utils import (
_get_kernel,
_register_explicit_noop,
_register_five_ten_crop_kernel,
_register_kernel_internal,
is_simple_tensor,
)
from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -46,30 +40,22 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(horizontal_flip)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(horizontal_flip, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image):
return horizontal_flip_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(horizontal_flip)
kernel = _get_kernel(horizontal_flip, type(inpt))
return kernel(inpt)
@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, datapoints.Image)
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1)
@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.hflip(image)
@@ -110,30 +96,22 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(vertical_flip)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(vertical_flip, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image):
return vertical_flip_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(vertical_flip)
kernel = _get_kernel(vertical_flip, type(inpt))
return kernel(inpt)
@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, datapoints.Image)
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-2)
@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
return _FP.vflip(image)
@@ -199,24 +177,16 @@ def resize(
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(resize)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(resize, type(inpt))
return kernel(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, PIL.Image.Image):
if antialias is False:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
if torch.jit.is_scripting():
return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
_log_api_usage_once(resize)
kernel = _get_kernel(resize, type(inpt))
return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, datapoints.Image)
def resize_image_tensor(
image: torch.Tensor,
@@ -297,7 +267,6 @@ def resize_image_tensor(
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
@torch.jit.unused
def resize_image_pil(
image: PIL.Image.Image,
size: Union[Sequence[int], int],
@@ -319,6 +288,19 @@ def resize_image_pil(
return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])
@_register_kernel_internal(resize, PIL.Image.Image)
def _resize_image_pil_dispatch(
image: PIL.Image.Image,
size: Union[Sequence[int], int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
if antialias is False:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
@@ -391,26 +373,10 @@ def affine(
fill: datapoints._FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(affine)
# TODO: consider deprecating integers from angle and shear on the future
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return affine_image_tensor(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(affine, type(inpt))
return kernel(
inpt,
angle,
angle=angle,
translate=translate,
scale=scale,
shear=shear,
@@ -418,22 +384,20 @@ def affine(
fill=fill,
center=center,
)
elif isinstance(inpt, PIL.Image.Image):
return affine_image_pil(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(affine)
kernel = _get_kernel(affine, type(inpt))
return kernel(
inpt,
angle=angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
def _affine_parse_args(
@@ -684,6 +648,7 @@ def _affine_grid(
return output_grid.view(1, oh, ow, 2)
@_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, datapoints.Image)
def affine_image_tensor(
image: torch.Tensor,
@@ -736,7 +701,7 @@ def affine_image_tensor(
return output
@torch.jit.unused
@_register_kernel_internal(affine, PIL.Image.Image)
def affine_image_pil(
image: PIL.Image.Image,
angle: Union[int, float],
@@ -983,23 +948,18 @@ def rotate(
center: Optional[List[float]] = None,
fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(rotate)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(rotate, type(inpt))
return kernel(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, PIL.Image.Image):
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
if torch.jit.is_scripting():
return rotate_image_tensor(
inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
_log_api_usage_once(rotate)
kernel = _get_kernel(rotate, type(inpt))
return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
@_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, datapoints.Image)
def rotate_image_tensor(
image: torch.Tensor,
@@ -1045,7 +1005,7 @@ def rotate_image_tensor(
return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
@torch.jit.unused
@_register_kernel_internal(rotate, PIL.Image.Image)
def rotate_image_pil(
image: PIL.Image.Image,
angle: float,
@@ -1162,22 +1122,13 @@ def pad(
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(pad)
if torch.jit.is_scripting():
return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
_log_api_usage_once(pad)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(pad, type(inpt))
return kernel(inpt, padding, fill=fill, padding_mode=padding_mode)
elif isinstance(inpt, PIL.Image.Image):
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
kernel = _get_kernel(pad, type(inpt))
return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
@@ -1204,6 +1155,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
return [pad_left, pad_right, pad_top, pad_bottom]
@_register_kernel_internal(pad, torch.Tensor)
@_register_kernel_internal(pad, datapoints.Image)
def pad_image_tensor(
image: torch.Tensor,
@@ -1303,7 +1255,7 @@ def _pad_with_vector_fill(
return output
pad_image_pil = _FP.pad
pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
@_register_kernel_internal(pad, datapoints.Mask)
@@ -1385,23 +1337,16 @@ def pad_video(
def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(crop)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return crop_image_tensor(inpt, top, left, height, width)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(crop, type(inpt))
return kernel(inpt, top, left, height, width)
elif isinstance(inpt, PIL.Image.Image):
return crop_image_pil(inpt, top, left, height, width)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
if torch.jit.is_scripting():
return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)
_log_api_usage_once(crop)
kernel = _get_kernel(crop, type(inpt))
return kernel(inpt, top=top, left=left, height=height, width=width)
@_register_kernel_internal(crop, torch.Tensor)
@_register_kernel_internal(crop, datapoints.Image)
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
h, w = image.shape[-2:]
@@ -1422,6 +1367,7 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid
crop_image_pil = _FP.crop
_register_kernel_internal(crop, PIL.Image.Image)(crop_image_pil)
def crop_bounding_boxes(
@@ -1484,25 +1430,28 @@ def perspective(
fill: datapoints._FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(perspective)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return perspective_image_tensor(
inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(perspective, type(inpt))
return kernel(inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients)
elif isinstance(inpt, PIL.Image.Image):
return perspective_image_pil(
inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
inpt,
startpoints=startpoints,
endpoints=endpoints,
interpolation=interpolation,
fill=fill,
coefficients=coefficients,
)
_log_api_usage_once(perspective)
kernel = _get_kernel(perspective, type(inpt))
return kernel(
inpt,
startpoints=startpoints,
endpoints=endpoints,
interpolation=interpolation,
fill=fill,
coefficients=coefficients,
)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
@@ -1551,6 +1500,7 @@ def _perspective_coefficients(
raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")
@_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, datapoints.Image)
def perspective_image_tensor(
image: torch.Tensor,
@@ -1598,7 +1548,7 @@ def perspective_image_tensor(
return output
@torch.jit.unused
@_register_kernel_internal(perspective, PIL.Image.Image)
def perspective_image_pil(
image: PIL.Image.Image,
startpoints: Optional[List[List[int]]],
@@ -1787,29 +1737,19 @@ def elastic(
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(elastic)
if not isinstance(displacement, torch.Tensor):
raise TypeError("Argument displacement should be a Tensor")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(elastic, type(inpt))
return kernel(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, PIL.Image.Image):
return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
if torch.jit.is_scripting():
return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
_log_api_usage_once(elastic)
kernel = _get_kernel(elastic, type(inpt))
return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
elastic_transform = elastic
@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, datapoints.Image)
def elastic_image_tensor(
image: torch.Tensor,
@@ -1867,7 +1807,7 @@ def elastic_image_tensor(
return output
@torch.jit.unused
@_register_kernel_internal(elastic, PIL.Image.Image)
def elastic_image_pil(
image: PIL.Image.Image,
displacement: torch.Tensor,
@@ -1990,21 +1930,13 @@ def elastic_video(
def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(center_crop)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(center_crop, type(inpt))
return kernel(inpt, output_size)
elif isinstance(inpt, PIL.Image.Image):
return center_crop_image_pil(inpt, output_size)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
if torch.jit.is_scripting():
return center_crop_image_tensor(inpt, output_size=output_size)
_log_api_usage_once(center_crop)
kernel = _get_kernel(center_crop, type(inpt))
return kernel(inpt, output_size=output_size)
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
@@ -2034,6 +1966,7 @@ def _center_crop_compute_crop_anchor(
return crop_top, crop_left
@_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, datapoints.Image)
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
@@ -2054,7 +1987,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
@torch.jit.unused
@_register_kernel_internal(center_crop, PIL.Image.Image)
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
image_height, image_width = get_size_image_pil(image)
@@ -2125,25 +2058,34 @@ def resized_crop(
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(resized_crop)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(resized_crop, type(inpt))
return kernel(inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
elif isinstance(inpt, PIL.Image.Image):
return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
inpt,
top=top,
left=left,
height=height,
width=width,
size=size,
interpolation=interpolation,
antialias=antialias,
)
_log_api_usage_once(resized_crop)
kernel = _get_kernel(resized_crop, type(inpt))
return kernel(
inpt,
top=top,
left=left,
height=height,
width=width,
size=size,
interpolation=interpolation,
antialias=antialias,
)
@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, datapoints.Image)
def resized_crop_image_tensor(
image: torch.Tensor,
@@ -2159,7 +2101,6 @@ def resized_crop_image_tensor(
return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias)
@torch.jit.unused
def resized_crop_image_pil(
image: PIL.Image.Image,
top: int,
@@ -2173,6 +2114,30 @@ def resized_crop_image_pil(
return resize_image_pil(image, size, interpolation=interpolation)
@_register_kernel_internal(resized_crop, PIL.Image.Image)
def resized_crop_image_pil_dispatch(
image: PIL.Image.Image,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
if antialias is False:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resized_crop_image_pil(
image,
top=top,
left=left,
height=height,
width=width,
size=size,
interpolation=interpolation,
)
def resized_crop_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
@@ -2244,21 +2209,13 @@ def five_crop(
datapoints._InputTypeJIT,
datapoints._InputTypeJIT,
]:
if not torch.jit.is_scripting():
_log_api_usage_once(five_crop)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return five_crop_image_tensor(inpt, size)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(five_crop, type(inpt))
return kernel(inpt, size)
elif isinstance(inpt, PIL.Image.Image):
return five_crop_image_pil(inpt, size)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
if torch.jit.is_scripting():
return five_crop_image_tensor(inpt, size=size)
_log_api_usage_once(five_crop)
kernel = _get_kernel(five_crop, type(inpt))
return kernel(inpt, size=size)
def _parse_five_crop_size(size: List[int]) -> List[int]:
@@ -2275,6 +2232,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
return size
@_register_five_ten_crop_kernel(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel(five_crop, datapoints.Image)
def five_crop_image_tensor(
image: torch.Tensor, size: List[int]
@@ -2294,7 +2252,7 @@ def five_crop_image_tensor(
return tl, tr, bl, br, center
@torch.jit.unused
@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image)
def five_crop_image_pil(
image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
@@ -2335,23 +2293,16 @@ def ten_crop(
datapoints._InputTypeJIT,
datapoints._InputTypeJIT,
]:
if not torch.jit.is_scripting():
_log_api_usage_once(ten_crop)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(ten_crop, type(inpt))
return kernel(inpt, size, vertical_flip=vertical_flip)
elif isinstance(inpt, PIL.Image.Image):
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
if torch.jit.is_scripting():
return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip)
_log_api_usage_once(ten_crop)
kernel = _get_kernel(ten_crop, type(inpt))
return kernel(inpt, size=size, vertical_flip=vertical_flip)
@_register_five_ten_crop_kernel(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
def ten_crop_image_tensor(
image: torch.Tensor, size: List[int], vertical_flip: bool = False
@@ -2379,7 +2330,7 @@ def ten_crop_image_tensor(
return non_flipped + flipped
@torch.jit.unused
@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image)
def ten_crop_image_pil(
image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
@@ -13,23 +13,16 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupporte
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
if not torch.jit.is_scripting():
_log_api_usage_once(get_dimensions)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return get_dimensions_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(get_dimensions, type(inpt))
return kernel(inpt)
elif isinstance(inpt, PIL.Image.Image):
return get_dimensions_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(get_dimensions)
kernel = _get_kernel(get_dimensions, type(inpt))
return kernel(inpt)
@_register_kernel_internal(get_dimensions, torch.Tensor)
@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:])
......@@ -43,7 +36,7 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
get_dimensions_image_pil = _FP.get_dimensions
get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
......@@ -53,23 +46,16 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
if not torch.jit.is_scripting():
_log_api_usage_once(get_num_channels)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return get_num_channels_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(get_num_channels, type(inpt))
return kernel(inpt)
elif isinstance(inpt, PIL.Image.Image):
return get_num_channels_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(get_num_channels)
kernel = _get_kernel(get_num_channels, type(inpt))
return kernel(inpt)
@_register_kernel_internal(get_num_channels, torch.Tensor)
@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
chw = image.shape[-3:]
......@@ -82,7 +68,7 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
get_num_channels_image_pil = _FP.get_image_num_channels
get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
......@@ -96,23 +82,16 @@ get_image_num_channels = get_num_channels
def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
if not torch.jit.is_scripting():
_log_api_usage_once(get_size)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return get_size_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(get_size, type(inpt))
return kernel(inpt)
elif isinstance(inpt, PIL.Image.Image):
return get_size_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(get_size)
kernel = _get_kernel(get_size, type(inpt))
return kernel(inpt)
@_register_kernel_internal(get_size, torch.Tensor)
@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
def get_size_image_tensor(image: torch.Tensor) -> List[int]:
hw = list(image.shape[-2:])
......@@ -123,7 +102,7 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
@torch.jit.unused
@_register_kernel_internal(get_size, PIL.Image.Image)
def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
width, height = _FP.get_image_size(image)
return [height, width]
......@@ -146,21 +125,16 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
if not torch.jit.is_scripting():
_log_api_usage_once(get_num_frames)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return get_num_frames_video(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(get_num_frames, type(inpt))
return kernel(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(get_num_frames)
kernel = _get_kernel(get_num_frames, type(inpt))
return kernel(inpt)
@_register_kernel_internal(get_num_frames, torch.Tensor)
@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
def get_num_frames_video(video: torch.Tensor) -> int:
return video.shape[-4]
......
......@@ -11,13 +11,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
from ._utils import (
_get_kernel,
_register_explicit_noop,
_register_kernel_internal,
_register_unsupported_type,
is_simple_tensor,
)
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
......@@ -28,19 +22,16 @@ def normalize(
std: List[float],
inplace: bool = False,
) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(normalize)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(normalize, type(inpt))
return kernel(inpt, mean=mean, std=std, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
)
_log_api_usage_once(normalize)
kernel = _get_kernel(normalize, type(inpt))
return kernel(inpt, mean=mean, std=std, inplace=inplace)
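For reference, the tensor kernel registered below applies the standard per-channel transform (image - mean) / std. A small sketch of the equivalent arithmetic, assuming a float CHW input:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 8, 8)
out = F.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
mean = torch.tensor([0.5, 0.5, 0.5])[:, None, None]
std = torch.tensor([0.5, 0.5, 0.5])[:, None, None]
assert torch.allclose(out, (img - mean) / std)  # same per-channel affine transform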
@_register_kernel_internal(normalize, torch.Tensor)
@_register_kernel_internal(normalize, datapoints.Image)
def normalize_image_tensor(
image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
......@@ -86,21 +77,13 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
def gaussian_blur(
inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(gaussian_blur)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(gaussian_blur, type(inpt))
return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, PIL.Image.Image):
return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
_log_api_usage_once(gaussian_blur)
kernel = _get_kernel(gaussian_blur, type(inpt))
return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
......@@ -119,6 +102,7 @@ def _get_gaussian_kernel2d(
return kernel2d
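The two helpers above build the blur kernel separably: a 1-D Gaussian is sampled on a symmetric grid, normalized, and outer-multiplied into the 2-D kernel. A hedged sketch mirroring the signatures, not the exact implementation:

import torch

def gaussian_kernel1d_sketch(kernel_size: int, sigma: float) -> torch.Tensor:
    # Sample the Gaussian pdf on [-(k - 1) / 2, ..., (k - 1) / 2] and normalize to sum 1.
    x = torch.linspace(-(kernel_size - 1) / 2, (kernel_size - 1) / 2, kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma) ** 2)
    return pdf / pdf.sum()

# 2-D kernel as the outer product of two 1-D kernels (separable filtering).
kernel2d = torch.outer(gaussian_kernel1d_sketch(5, 1.0), gaussian_kernel1d_sketch(5, 1.0))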
@_register_kernel_internal(gaussian_blur, torch.Tensor)
@_register_kernel_internal(gaussian_blur, datapoints.Image)
def gaussian_blur_image_tensor(
image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
......@@ -184,7 +168,7 @@ def gaussian_blur_image_tensor(
return output
@torch.jit.unused
@_register_kernel_internal(gaussian_blur, PIL.Image.Image)
def gaussian_blur_image_pil(
image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> PIL.Image.Image:
......@@ -200,21 +184,17 @@ def gaussian_blur_video(
return gaussian_blur_image_tensor(video, kernel_size, sigma)
@_register_unsupported_type(PIL.Image.Image)
def to_dtype(
inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(to_dtype)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return to_dtype_image_tensor(inpt, dtype, scale=scale)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(to_dtype, type(inpt))
return kernel(inpt, dtype, scale=scale)
else:
raise TypeError(
f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
)
if torch.jit.is_scripting():
return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
_log_api_usage_once(to_dtype)
kernel = _get_kernel(to_dtype, type(inpt))
return kernel(inpt, dtype=dtype, scale=scale)
def _num_value_bits(dtype: torch.dtype) -> int:
......@@ -232,6 +212,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
@_register_kernel_internal(to_dtype, torch.Tensor)
@_register_kernel_internal(to_dtype, datapoints.Image)
def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
......
......@@ -5,27 +5,23 @@ from torchvision import datapoints
from torchvision.utils import _log_api_usage_once
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
@_register_explicit_noop(
PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
)
def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(uniform_temporal_subsample)
if torch.jit.is_scripting():
return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return uniform_temporal_subsample_video(inpt, num_samples)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
return kernel(inpt, num_samples)
else:
raise TypeError(
f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
)
_log_api_usage_once(uniform_temporal_subsample)
kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
return kernel(inpt, num_samples=num_samples)
@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
......
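For context before the registry utilities below: the pytorchvideo-style subsampling referenced above picks evenly spaced frame indices along the temporal dimension (dim -4 for [..., T, C, H, W] videos). A rough sketch of that index selection:

import torch

def uniform_temporal_subsample_sketch(video: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Evenly spaced indices over [0, T - 1], gathered along the T dimension.
    t = video.shape[-4]
    indices = torch.linspace(0, t - 1, num_samples).long()
    return torch.index_select(video, -4, indices)

clip = uniform_temporal_subsample_sketch(torch.rand(16, 3, 8, 8), num_samples=4)  # 4 x 3 x 8 x 8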
......@@ -23,15 +23,17 @@ def _kernel_datapoint_wrapper(kernel):
return wrapper
def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=True):
def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
if datapoint_cls in registry:
raise TypeError(
f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
)
if input_type in registry:
raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")
def decorator(kernel):
registry[datapoint_cls] = _kernel_datapoint_wrapper(kernel) if datapoint_wrapper else kernel
registry[input_type] = (
_kernel_datapoint_wrapper(kernel)
if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
else kernel
)
return kernel
return decorator
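With the issubclass check above, the same internal helper now serves all input families: datapoint registrations get the output re-wrapping, while torch.Tensor and PIL.Image.Image entries are stored as-is. A hedged sketch with a hypothetical dispatcher my_op, not part of this diff:

import torch
from torchvision import datapoints
from torchvision.transforms.v2.functional._utils import _register_kernel_internal

def my_op(inpt):  # hypothetical dispatcher, stand-in for e.g. normalize
    ...

@_register_kernel_internal(my_op, torch.Tensor)      # non-datapoint type: stored unwrapped
@_register_kernel_internal(my_op, datapoints.Image)  # datapoint type: gets _kernel_datapoint_wrapper
def my_op_image_tensor(image: torch.Tensor) -> torch.Tensor:
    return image.clone()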
......@@ -43,7 +45,9 @@ def _name_to_dispatcher(name):
try:
return getattr(torchvision.transforms.v2.functional, name)
except AttributeError:
raise ValueError(f"Could not find dispatcher with name '{name}'.") from None
raise ValueError(
f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
) from None
def register_kernel(dispatcher, datapoint_cls):
......@@ -54,22 +58,57 @@ def register_kernel(dispatcher, datapoint_cls):
"""
if isinstance(dispatcher, str):
dispatcher = _name_to_dispatcher(name=dispatcher)
elif not (
callable(dispatcher)
and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
):
raise ValueError(
f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
f"but got {dispatcher}."
)
if not (
isinstance(datapoint_cls, type)
and issubclass(datapoint_cls, datapoints.Datapoint)
and datapoint_cls is not datapoints.Datapoint
):
raise ValueError(
f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
f"but got {datapoint_cls}."
)
return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
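The validation above restricts the public entry point to dispatchers from the torchvision.transforms.v2.functional namespace and to proper Datapoint subclasses (Datapoint itself is rejected). A usage sketch with a hypothetical custom datapoint, assuming register_kernel is re-exported from the functional namespace:

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

class MyDatapoint(datapoints.Datapoint):  # hypothetical user-defined datapoint
    pass

@F.register_kernel("gaussian_blur", MyDatapoint)  # string names resolve via _name_to_dispatcher
def gaussian_blur_my_datapoint(inpt, kernel_size, sigma=None):
    ...  # datapoint_wrapper=False here, so the kernel returns its output as-is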
def _get_kernel(dispatcher, datapoint_cls):
def _get_kernel(dispatcher, input_type):
registry = _KERNEL_REGISTRY.get(dispatcher)
if not registry:
raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.")
if datapoint_cls in registry:
return registry[datapoint_cls]
for registered_cls, kernel in registry.items():
if issubclass(datapoint_cls, registered_cls):
return kernel
return _noop
raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
# In case we have an exact type match, we take a shortcut.
if input_type in registry:
return registry[input_type]
# For datapoints, we check whether a kernel is registered for one of their superclasses
if issubclass(input_type, datapoints.Datapoint):
# Since we have already checked for an exact match above, we can start the traversal at the superclass.
for cls in input_type.__mro__[1:]:
if cls is datapoints.Datapoint:
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly stop the
# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# allow kernels to be registered for datapoints.Datapoint anyway.
break
elif cls in registry:
return registry[cls]
# Note that in the future we are not going to return a noop here, but rather raise the error below
return _noop
raise TypeError(
f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
f"and subclasses of torchvision.datapoints.Datapoint, "
f"but got {input_type} instead."
)
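The traversal above means a user-defined subclass of a built-in datapoint inherits its kernels, but never silently falls through to the plain-tensor ones. A sketch with a hypothetical subclass:

from torchvision import datapoints
from torchvision.transforms.v2.functional import get_size
from torchvision.transforms.v2.functional._utils import _get_kernel

class RotatedImage(datapoints.Image):  # hypothetical subclass
    pass

# No exact entry for RotatedImage, so the MRO walk (Image, Datapoint, ...) returns the
# datapoints.Image kernel; the walk stops at Datapoint, so torch.Tensor is never reached.
kernel = _get_kernel(get_size, RotatedImage)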
# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
......@@ -101,7 +140,9 @@ def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
f"This will likely change in the future."
)
register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if warn_passthrough else None))
_register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
functools.partial(_noop, __msg__=msg if warn_passthrough else None)
)
return dispatcher
return decorator
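Routing the noop through _register_kernel_internal instead of register_kernel sidesteps the new public validation, which would reject non-datapoint types such as the PIL.Image.Image noop on uniform_temporal_subsample above. The effect for normalize, which registers BoundingBoxes and Mask as noops (a sketch, assuming the current BoundingBoxes constructor signature):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10]]), format="XYXY", canvas_size=(32, 32)
)
out = F.normalize(boxes, mean=[0.5], std=[0.5])  # passed through unchanged via _noop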
......@@ -115,13 +156,15 @@ def _noop(inpt, *args, __msg__=None, **kwargs):
# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
# to error later, this decorator can be removed, since the error will be raised by _get_kernel
def _register_unsupported_type(*datapoints_classes):
def _register_unsupported_type(*input_types):
def kernel(inpt, *args, __dispatcher_name__, **kwargs):
raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
def decorator(dispatcher):
for cls in datapoints_classes:
register_kernel(dispatcher, cls)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__))
for input_type in input_types:
_register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
)
return dispatcher
return decorator
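Until passthrough-by-default becomes a hard error, this decorator makes unsupported combinations fail loudly. get_num_frames above registers PIL images (among others) as unsupported, so:

import PIL.Image
from torchvision.transforms.v2 import functional as F

F.get_num_frames(PIL.Image.new("RGB", (4, 4)))
# TypeError: F.get_num_frames does not support inputs of type <class 'PIL.Image.Image'>.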
......@@ -129,13 +172,10 @@ def _register_unsupported_type(*datapoints_classes):
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take an arbitrary wrapper rather than the datapoint_wrapper: bool flag
# TODO: decide if we want that
def _register_five_ten_crop_kernel(dispatcher, datapoint_cls):
def _register_five_ten_crop_kernel(dispatcher, input_type):
registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
if datapoint_cls in registry:
raise TypeError(
f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
)
if input_type in registry:
raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
def wrap(kernel):
@functools.wraps(kernel)
......@@ -147,7 +187,7 @@ def _register_five_ten_crop_kernel(dispatcher, datapoint_cls):
return wrapper
def decorator(kernel):
registry[datapoint_cls] = wrap(kernel)
registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
return kernel
return decorator
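The specialized wrapper exists because five_crop / ten_crop return a tuple of outputs, so the generic _kernel_datapoint_wrapper, which re-wraps a single result, does not fit. The wrapper body is collapsed in this diff; a hypothetical sketch of its shape:

import functools
import torch

def _five_ten_crop_wrap_sketch(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # Run the kernel on the unwrapped tensor, then re-wrap every element of the
        # output tuple, since the generic datapoint wrapper handles only one output.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return tuple(type(inpt).wrap_like(inpt, piece) for piece in output)
    return wrapper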