Unverified Commit 2030d208 authored by Philip Meier, committed by GitHub

register tensor and PIL kernel the same way as datapoints (#7797)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 84db2ac4
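
Note: this commit makes plain torch.Tensor and PIL.Image.Image inputs go through the same kernel registry that datapoints already use, so every dispatcher body collapses to a scripting fast path plus a single _get_kernel lookup. A minimal sketch of the mechanism, with simplified bodies (the helper names exist in torchvision.transforms.v2.functional._utils, but the bodies below are illustrative assumptions, not the actual implementation):

    # Simplified model of the registry used throughout this diff.
    _KERNEL_REGISTRY = {}  # dispatcher -> {input_type: kernel}

    def _register_kernel_internal(dispatcher, input_type):
        def decorator(kernel):
            _KERNEL_REGISTRY.setdefault(dispatcher, {})[input_type] = kernel
            return kernel
        return decorator

    def _get_kernel(dispatcher, input_type):
        registry = _KERNEL_REGISTRY[dispatcher]
        # Exact match first, then walk the MRO so e.g. a datapoints.Image
        # subclass falls back to the kernel registered for datapoints.Image.
        for cls in input_type.__mro__:
            if cls in registry:
                return registry[cls]
        raise TypeError(f"{dispatcher.__name__} does not support inputs of type {input_type}")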
@@ -2,7 +2,6 @@ import inspect
 import math
 import os
 import re
-from unittest import mock

 import numpy as np
 import PIL.Image
@@ -25,7 +24,6 @@ from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
-from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
 from torchvision.transforms.v2.utils import is_simple_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -359,18 +357,6 @@ class TestDispatchers:
     def test_scriptable(self, dispatcher):
         script(dispatcher)

-    @image_sample_inputs
-    def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
-        (image_datapoint, *other_args), kwargs = args_kwargs.load()
-        image_simple_tensor = torch.Tensor(image_datapoint)
-
-        kernel_info = info.kernel_infos[datapoints.Image]
-        spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
-
-        info.dispatcher(image_simple_tensor, *other_args, **kwargs)
-
-        spy.assert_called_once()
-
     @image_sample_inputs
     def test_simple_tensor_output_type(self, info, args_kwargs):
         (image_datapoint, *other_args), kwargs = args_kwargs.load()
@@ -381,25 +367,6 @@ class TestDispatchers:
         # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
         assert type(output) is torch.Tensor

-    @make_info_args_kwargs_parametrization(
-        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
-        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
-    )
-    def test_dispatch_pil(self, info, args_kwargs, spy_on):
-        (image_datapoint, *other_args), kwargs = args_kwargs.load()
-
-        if image_datapoint.ndim > 3:
-            pytest.skip("Input is batched")
-
-        image_pil = F.to_image_pil(image_datapoint)
-
-        pil_kernel_info = info.pil_kernel_info
-        spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
-
-        info.dispatcher(image_pil, *other_args, **kwargs)
-
-        spy.assert_called_once()
-
     @make_info_args_kwargs_parametrization(
         [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
         args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
@@ -416,28 +383,6 @@ class TestDispatchers:
         assert isinstance(output, PIL.Image.Image)

-    @make_info_args_kwargs_parametrization(
-        DISPATCHER_INFOS,
-        args_kwargs_fn=lambda info: info.sample_inputs(),
-    )
-    def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
-        (datapoint, *other_args), kwargs = args_kwargs.load()
-
-        input_type = type(datapoint)
-        wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]
-
-        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
-        # proper kernel was wrapped
-        if hasattr(wrapped_kernel, "__wrapped__"):
-            assert wrapped_kernel.__wrapped__ is info.kernels[input_type]
-
-        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
-        with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
-            info.dispatcher(datapoint, *other_args, **kwargs)
-
-        spy.assert_called_once()
-
     @make_info_args_kwargs_parametrization(
         DISPATCHER_INFOS,
         args_kwargs_fn=lambda info: info.sample_inputs(),
@@ -449,6 +394,9 @@ class TestDispatchers:
         assert isinstance(output, type(datapoint))

+        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
+            assert output.format == datapoint.format
+
     @pytest.mark.parametrize(
         ("dispatcher_info", "datapoint_type", "kernel_info"),
         [
...
@@ -39,7 +39,7 @@ from torchvision import datapoints
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
+from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal


 @pytest.fixture(autouse=True)
@@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
     dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)


-def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
-    """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
-    preserved in doing so. For bounding boxes also checks that the format is preserved.
-    """
-    input_type = type(input)
-
-    if isinstance(input, datapoints.Datapoint):
-        wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]
-
-        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
-        # proper kernel was wrapped
-        if hasattr(wrapped_kernel, "__wrapped__"):
-            assert wrapped_kernel.__wrapped__ is kernel
-
-        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
-        with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
-            output = dispatcher(input, *args, **kwargs)
-
-        spy.assert_called_once()
-    else:
-        with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy:
-            output = dispatcher(input, *args, **kwargs)
-
-            spy.assert_called_once()
-
-    assert isinstance(output, input_type)
-
-    if isinstance(input, datapoints.BoundingBoxes):
-        assert output.format == input.format
-
-
 def check_dispatcher(
     dispatcher,
+    # TODO: remove this parameter
     kernel,
     input,
     *args,
     check_scripted_smoke=True,
-    check_dispatch=True,
     **kwargs,
 ):
     unknown_input = object()
-    with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
-        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
-            dispatcher(unknown_input, *args, **kwargs)
+    with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
+        dispatcher(unknown_input, *args, **kwargs)
+
+    with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
+        output = dispatcher(input, *args, **kwargs)

         spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")

+    assert isinstance(output, type(input))
+
+    if isinstance(input, datapoints.BoundingBoxes):
+        assert output.format == input.format
+
     if check_scripted_smoke:
         _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
-
-    if check_dispatch:
-        _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)
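
Note: the deleted helper was the only place that spied on which kernel a dispatcher picks; for one-off debugging, the same trick still works against the registry directly. A minimal sketch, assuming the private names from this diff stay as they are (the helper name below is ours, not torchvision's):

    from unittest import mock

    from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY

    def assert_dispatches_to_registered_kernel(dispatcher, input, *args, **kwargs):
        # Look up the (possibly wrapped) kernel registered for this input type ...
        wrapped_kernel = _KERNEL_REGISTRY[dispatcher][type(input)]
        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
        # ... temporarily swap it for a spy and make one dispatch call.
        with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={type(input): spy}):
            dispatcher(input, *args, **kwargs)
        spy.assert_called_once()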
 def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
     """Checks if the signature of the dispatcher matches the kernel signature."""
@@ -412,18 +385,20 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size):
 @pytest.mark.parametrize(
-    ("dispatcher", "registered_datapoint_clss"),
+    ("dispatcher", "registered_input_types"),
     [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
 )
-def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
+def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
     missing = {
+        torch.Tensor,
+        PIL.Image.Image,
         datapoints.Image,
         datapoints.BoundingBoxes,
         datapoints.Mask,
         datapoints.Video,
-    } - registered_datapoint_clss
+    } - registered_input_types
     if missing:
-        names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
+        names = sorted(str(t) for t in missing)
         raise AssertionError(
             "\n".join(
                 [
@@ -1753,11 +1728,6 @@ class TestToDtype:
             F.to_dtype,
             kernel,
             make_input(dtype=input_dtype, device=device),
-            # TODO: we could leave check_dispatch to True but it currently fails
-            # in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
-            # We should be able to put this back if we change the dispatch
-            # mechanism e.g. via https://github.com/pytorch/vision/pull/7733
-            check_dispatch=False,
             dtype=output_dtype,
             scale=scale,
         )
@@ -2208,9 +2178,105 @@ class TestRegisterKernel:
         t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
         t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)

-    def test_bad_disaptcher_name(self):
-        class CustomDatapoint(datapoints.Datapoint):
-            pass
-
-        with pytest.raises(ValueError, match="Could not find dispatcher with name"):
-            F.register_kernel("bad_name", CustomDatapoint)
+    def test_errors(self):
+        with pytest.raises(ValueError, match="Could not find dispatcher with name"):
+            F.register_kernel("bad_name", datapoints.Image)
+
+        with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
+            F.register_kernel(datapoints.Image, F.resize)
+
+        with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
+            F.register_kernel(F.resize, object)
+
+        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+            F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
+
+
+class TestGetKernel:
+    # We are using F.resize as the dispatcher and the kernels below as proxies. Any other
+    # dispatcher / kernel combination would also be fine.
+    KERNELS = {
+        torch.Tensor: F.resize_image_tensor,
+        PIL.Image.Image: F.resize_image_pil,
+        datapoints.Image: F.resize_image_tensor,
+        datapoints.BoundingBoxes: F.resize_bounding_boxes,
+        datapoints.Mask: F.resize_mask,
+        datapoints.Video: F.resize_video,
+    }
+
+    def test_unsupported_types(self):
+        class MyTensor(torch.Tensor):
+            pass
+
+        class MyPILImage(PIL.Image.Image):
+            pass
+
+        for input_type in [str, int, object, MyTensor, MyPILImage]:
+            with pytest.raises(
+                TypeError,
+                match=(
+                    "supports inputs of type torch.Tensor, PIL.Image.Image, "
+                    "and subclasses of torchvision.datapoints.Datapoint"
+                ),
+            ):
+                _get_kernel(F.resize, input_type)
+
+    def test_exact_match(self):
+        # We cannot use F.resize together with the self.KERNELS mapping directly here, since that is only the
+        # ideal wiring. In practice, there is an intermediate wrapper layer. Thus, we create a new resize
+        # dispatcher here, register the kernels without the wrapper, and check the exact matching afterwards.
+        def resize_with_pure_kernels():
+            pass
+
+        for input_type, kernel in self.KERNELS.items():
+            _register_kernel_internal(resize_with_pure_kernels, input_type, datapoint_wrapper=False)(kernel)
+
+            assert _get_kernel(resize_with_pure_kernels, input_type) is kernel
+
+    def test_builtin_datapoint_subclass(self):
+        # We cannot use F.resize together with the self.KERNELS mapping directly here, since that is only the
+        # ideal wiring. In practice, there is an intermediate wrapper layer. Thus, we create a new resize
+        # dispatcher here, register the kernels without the wrapper, and check that subclasses of our builtin
+        # datapoints get dispatched to the kernel of the corresponding superclass.
+        def resize_with_pure_kernels():
+            pass
+
+        class MyImage(datapoints.Image):
+            pass
+
+        class MyBoundingBoxes(datapoints.BoundingBoxes):
+            pass
+
+        class MyMask(datapoints.Mask):
+            pass
+
+        class MyVideo(datapoints.Video):
+            pass
+
+        for custom_datapoint_subclass in [
+            MyImage,
+            MyBoundingBoxes,
+            MyMask,
+            MyVideo,
+        ]:
+            builtin_datapoint_class = custom_datapoint_subclass.__mro__[1]
+            builtin_datapoint_kernel = self.KERNELS[builtin_datapoint_class]
+            _register_kernel_internal(resize_with_pure_kernels, builtin_datapoint_class, datapoint_wrapper=False)(
+                builtin_datapoint_kernel
+            )
+
+            assert _get_kernel(resize_with_pure_kernels, custom_datapoint_subclass) is builtin_datapoint_kernel
+
+    def test_datapoint_subclass(self):
+        class MyDatapoint(datapoints.Datapoint):
+            pass
+
+        # Note that this will be an error in the future
+        assert _get_kernel(F.resize, MyDatapoint) is _noop
+
+        def resize_my_datapoint():
+            pass
+
+        _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)
+
+        assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
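
Note: taken together, these tests pin down the lookup rules: exact type first, then the closest registered ancestor, with unregistered Datapoint subclasses currently falling back to _noop. For third-party code the public entry point is F.register_kernel, exercised by TestRegisterKernel above; a hypothetical end-to-end use with a custom datapoint (the class and kernel names below are ours, not torchvision's):

    class MyDatapoint(datapoints.Datapoint):
        pass

    @F.register_kernel(F.resize, MyDatapoint)
    def resize_my_datapoint(inpt, size, **kwargs):
        # Delegate to the plain-tensor kernel; a real kernel would do its own work.
        return F.resize_image_tensor(inpt.as_subclass(torch.Tensor), size, **kwargs)

    # F.resize now routes MyDatapoint instances through resize_my_datapoint.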
...
@@ -7,7 +7,7 @@ from torchvision import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
@@ -20,23 +20,16 @@ def erase(
     v: torch.Tensor,
     inplace: bool = False,
 ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(erase)
+    if torch.jit.is_scripting():
+        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(erase, type(inpt))
-        return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    elif isinstance(inpt, PIL.Image.Image):
-        return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    _log_api_usage_once(erase)
+
+    kernel = _get_kernel(erase, type(inpt))
+    return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)


+@_register_kernel_internal(erase, torch.Tensor)
 @_register_kernel_internal(erase, datapoints.Image)
 def erase_image_tensor(
     image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
@@ -48,7 +41,7 @@ def erase_image_tensor(
     return image


-@torch.jit.unused
+@_register_kernel_internal(erase, PIL.Image.Image)
 def erase_image_pil(
     image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> PIL.Image.Image:
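
Note: with both the tensor and the PIL kernel in the registry, the dispatcher body above is the whole story: scripting takes the tensor fast path, everything else goes through one _get_kernel lookup. A small usage sketch of the resulting behavior (our example, not part of the diff; assumes an RGB PIL image):

    import PIL.Image
    import torch
    from torchvision.transforms.v2 import functional as F

    # PIL input now reaches erase_image_pil through the registry rather than an
    # explicit isinstance branch, and the output type is preserved.
    img = PIL.Image.new("RGB", (8, 8))
    out = F.erase(img, i=0, j=0, h=4, w=4, v=torch.zeros(3, 4, 4))
    assert isinstance(out, PIL.Image.Image)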
...
@@ -10,29 +10,20 @@ from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once

 from ._misc import _num_value_bits, to_dtype_image_tensor
-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal


 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
 def rgb_to_grayscale(
     inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
 ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(rgb_to_grayscale)
-    if num_output_channels not in (1, 3):
-        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(rgb_to_grayscale, type(inpt))
-        return kernel(inpt, num_output_channels=num_output_channels)
-    elif isinstance(inpt, PIL.Image.Image):
-        return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    _log_api_usage_once(rgb_to_grayscale)
+
+    kernel = _get_kernel(rgb_to_grayscale, type(inpt))
+    return kernel(inpt, num_output_channels=num_output_channels)


 # `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a
@@ -56,12 +47,19 @@ def _rgb_to_grayscale_image_tensor(
     return l_img


+@_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
 @_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
 def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
+    if num_output_channels not in (1, 3):
+        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
     return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)


-rgb_to_grayscale_image_pil = _FP.to_grayscale
+@_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image)
+def rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
+    if num_output_channels not in (1, 3):
+        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+    return _FP.to_grayscale(image, num_output_channels=num_output_channels)
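
Note: the num_output_channels validation moves out of the dispatcher and into both kernels, which keeps the dispatcher body uniform while preserving the error on every input path, including the scripting fast path and direct kernel calls. An illustrative check (ours, not from the diff):

    import pytest
    import torch
    from torchvision.transforms.v2 import functional as F

    # The registered tensor kernel now validates its own arguments.
    with pytest.raises(ValueError, match="num_output_channels must be 1 or 3"):
        F.rgb_to_grayscale(torch.rand(3, 8, 8), num_output_channels=2)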
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -74,23 +72,16 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
+
     _log_api_usage_once(adjust_brightness)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(adjust_brightness, type(inpt))
-        return kernel(inpt, brightness_factor=brightness_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(adjust_brightness, type(inpt))
+    return kernel(inpt, brightness_factor=brightness_factor)


+@_register_kernel_internal(adjust_brightness, torch.Tensor)
 @_register_kernel_internal(adjust_brightness, datapoints.Image)
 def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if brightness_factor < 0:
@@ -106,6 +97,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     return output if fp else output.to(image.dtype)


+@_register_kernel_internal(adjust_brightness, PIL.Image.Image)
 def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image:
     return _FP.adjust_brightness(image, brightness_factor=brightness_factor)
@@ -117,23 +109,16 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
+
     _log_api_usage_once(adjust_saturation)

-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
-        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(adjust_saturation, type(inpt))
-        return kernel(inpt, saturation_factor=saturation_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(adjust_saturation, type(inpt))
+    return kernel(inpt, saturation_factor=saturation_factor)


+@_register_kernel_internal(adjust_saturation, torch.Tensor)
 @_register_kernel_internal(adjust_saturation, datapoints.Image)
 def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if saturation_factor < 0:
@@ -153,7 +138,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     return _blend(image, grayscale_image, saturation_factor)


-adjust_saturation_image_pil = _FP.adjust_saturation
+adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)


 @_register_kernel_internal(adjust_saturation, datapoints.Video)
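
Note on the two registration spellings in this file: newly defined PIL kernels use the decorator form, while pre-existing v1 helpers from _functional_pil are registered by calling the decorator directly, which keeps the module-level alias importable. Only one kernel may be registered per (dispatcher, input type) pair, per the "already has a kernel registered for type" error tested earlier. An illustrative comparison (the decorated body is hypothetical, shown in comments so it does not conflict with the direct-call form):

    # Decorator form, for kernels defined in this module:
    #
    #     @_register_kernel_internal(adjust_saturation, PIL.Image.Image)
    #     def adjust_saturation_image_pil(image, saturation_factor):
    #         return _FP.adjust_saturation(image, saturation_factor=saturation_factor)
    #
    # Direct-call form, registering an existing callable, as the diff does:
    adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)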
@@ -163,23 +148,16 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
+
     _log_api_usage_once(adjust_contrast)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(adjust_contrast, type(inpt))
-        return kernel(inpt, contrast_factor=contrast_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(adjust_contrast, type(inpt))
+    return kernel(inpt, contrast_factor=contrast_factor)


+@_register_kernel_internal(adjust_contrast, torch.Tensor)
 @_register_kernel_internal(adjust_contrast, datapoints.Image)
 def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if contrast_factor < 0:
@@ -199,7 +177,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     return _blend(image, mean, contrast_factor)


-adjust_contrast_image_pil = _FP.adjust_contrast
+adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)


 @_register_kernel_internal(adjust_contrast, datapoints.Video)
@@ -209,23 +187,16 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
+
     _log_api_usage_once(adjust_sharpness)

-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
-        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(adjust_sharpness, type(inpt))
-        return kernel(inpt, sharpness_factor=sharpness_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(adjust_sharpness, type(inpt))
+    return kernel(inpt, sharpness_factor=sharpness_factor)


+@_register_kernel_internal(adjust_sharpness, torch.Tensor)
 @_register_kernel_internal(adjust_sharpness, datapoints.Image)
 def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     num_channels, height, width = image.shape[-3:]
@@ -279,7 +250,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     return output


-adjust_sharpness_image_pil = _FP.adjust_sharpness
+adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)


 @_register_kernel_internal(adjust_sharpness, datapoints.Video)
@@ -289,21 +260,13 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
+
     _log_api_usage_once(adjust_hue)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(adjust_hue, type(inpt))
-        return kernel(inpt, hue_factor=hue_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(adjust_hue, type(inpt))
+    return kernel(inpt, hue_factor=hue_factor)


 def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
@@ -370,6 +333,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
     return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)


+@_register_kernel_internal(adjust_hue, torch.Tensor)
 @_register_kernel_internal(adjust_hue, datapoints.Image)
 def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if not (-0.5 <= hue_factor <= 0.5):
@@ -398,7 +362,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
     return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True)


-adjust_hue_image_pil = _FP.adjust_hue
+adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)


 @_register_kernel_internal(adjust_hue, datapoints.Video)
@@ -408,23 +372,16 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
+
     _log_api_usage_once(adjust_gamma)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(adjust_gamma, type(inpt))
-        return kernel(inpt, gamma=gamma, gain=gain)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(adjust_gamma, type(inpt))
+    return kernel(inpt, gamma=gamma, gain=gain)


+@_register_kernel_internal(adjust_gamma, torch.Tensor)
 @_register_kernel_internal(adjust_gamma, datapoints.Image)
 def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
     if gamma < 0:
@@ -445,7 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
     return to_dtype_image_tensor(output, image.dtype, scale=True)


-adjust_gamma_image_pil = _FP.adjust_gamma
+adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)


 @_register_kernel_internal(adjust_gamma, datapoints.Video)
@@ -455,23 +412,16 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return posterize_image_tensor(inpt, bits=bits)
+
     _log_api_usage_once(posterize)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return posterize_image_tensor(inpt, bits=bits)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(posterize, type(inpt))
-        return kernel(inpt, bits=bits)
-    elif isinstance(inpt, PIL.Image.Image):
-        return posterize_image_pil(inpt, bits=bits)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(posterize, type(inpt))
+    return kernel(inpt, bits=bits)


+@_register_kernel_internal(posterize, torch.Tensor)
 @_register_kernel_internal(posterize, datapoints.Image)
 def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
     if image.is_floating_point():
@@ -486,7 +436,7 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
     return image & mask


-posterize_image_pil = _FP.posterize
+posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)


 @_register_kernel_internal(posterize, datapoints.Video)
@@ -496,23 +446,16 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return solarize_image_tensor(inpt, threshold=threshold)
+
     _log_api_usage_once(solarize)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return solarize_image_tensor(inpt, threshold=threshold)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(solarize, type(inpt))
-        return kernel(inpt, threshold=threshold)
-    elif isinstance(inpt, PIL.Image.Image):
-        return solarize_image_pil(inpt, threshold=threshold)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(solarize, type(inpt))
+    return kernel(inpt, threshold=threshold)


+@_register_kernel_internal(solarize, torch.Tensor)
 @_register_kernel_internal(solarize, datapoints.Image)
 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
     if threshold > _max_value(image.dtype):
@@ -521,7 +464,7 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
     return torch.where(image >= threshold, invert_image_tensor(image), image)


-solarize_image_pil = _FP.solarize
+solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)


 @_register_kernel_internal(solarize, datapoints.Video)
@@ -531,25 +474,16 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return autocontrast_image_tensor(inpt)
+
     _log_api_usage_once(autocontrast)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return autocontrast_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(autocontrast, type(inpt))
-        return kernel(
-            inpt,
-        )
-    elif isinstance(inpt, PIL.Image.Image):
-        return autocontrast_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(autocontrast, type(inpt))
+    return kernel(inpt)


+@_register_kernel_internal(autocontrast, torch.Tensor)
 @_register_kernel_internal(autocontrast, datapoints.Image)
 def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
     c = image.shape[-3]
@@ -580,7 +514,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype)


-autocontrast_image_pil = _FP.autocontrast
+autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)


 @_register_kernel_internal(autocontrast, datapoints.Video)
@@ -590,25 +524,16 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return equalize_image_tensor(inpt)
+
     _log_api_usage_once(equalize)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return equalize_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(equalize, type(inpt))
-        return kernel(
-            inpt,
-        )
-    elif isinstance(inpt, PIL.Image.Image):
-        return equalize_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(equalize, type(inpt))
+    return kernel(inpt)


+@_register_kernel_internal(equalize, torch.Tensor)
 @_register_kernel_internal(equalize, datapoints.Image)
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
@@ -679,7 +604,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return to_dtype_image_tensor(output, output_dtype, scale=True)


-equalize_image_pil = _FP.equalize
+equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)


 @_register_kernel_internal(equalize, datapoints.Video)
@@ -689,25 +614,16 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return invert_image_tensor(inpt)
+
     _log_api_usage_once(invert)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return invert_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(invert, type(inpt))
-        return kernel(
-            inpt,
-        )
-    elif isinstance(inpt, PIL.Image.Image):
-        return invert_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(invert, type(inpt))
+    return kernel(inpt)


+@_register_kernel_internal(invert, torch.Tensor)
 @_register_kernel_internal(invert, datapoints.Image)
 def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.is_floating_point():
@@ -719,7 +635,7 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)


-invert_image_pil = _FP.invert
+invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)


 @_register_kernel_internal(invert, datapoints.Video)
...
@@ -25,13 +25,7 @@ from torchvision.utils import _log_api_usage_once
 from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
-from ._utils import (
-    _get_kernel,
-    _register_explicit_noop,
-    _register_five_ten_crop_kernel,
-    _register_kernel_internal,
-    is_simple_tensor,
-)
+from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal


 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -46,30 +40,22 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
 def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return horizontal_flip_image_tensor(inpt)
+
     _log_api_usage_once(horizontal_flip)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return horizontal_flip_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(horizontal_flip, type(inpt))
-        return kernel(
-            inpt,
-        )
-    elif isinstance(inpt, PIL.Image.Image):
-        return horizontal_flip_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(horizontal_flip, type(inpt))
+    return kernel(inpt)


+@_register_kernel_internal(horizontal_flip, torch.Tensor)
 @_register_kernel_internal(horizontal_flip, datapoints.Image)
 def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-1)


+@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
 def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
     return _FP.hflip(image)
@@ -110,30 +96,22 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
 def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return vertical_flip_image_tensor(inpt)
+
     _log_api_usage_once(vertical_flip)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return vertical_flip_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(vertical_flip, type(inpt))
-        return kernel(
-            inpt,
-        )
-    elif isinstance(inpt, PIL.Image.Image):
-        return vertical_flip_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(vertical_flip, type(inpt))
+    return kernel(inpt)


+@_register_kernel_internal(vertical_flip, torch.Tensor)
 @_register_kernel_internal(vertical_flip, datapoints.Image)
 def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-2)


+@_register_kernel_internal(vertical_flip, PIL.Image.Image)
 def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
     return _FP.vflip(image)
@@ -199,24 +177,16 @@ def resize(
     max_size: Optional[int] = None,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
+    if torch.jit.is_scripting():
+        return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
+
     _log_api_usage_once(resize)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(resize, type(inpt))
-        return kernel(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
-    elif isinstance(inpt, PIL.Image.Image):
-        if antialias is False:
-            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
-        return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    kernel = _get_kernel(resize, type(inpt))
+    return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


+@_register_kernel_internal(resize, torch.Tensor)
 @_register_kernel_internal(resize, datapoints.Image)
 def resize_image_tensor(
     image: torch.Tensor,
@@ -297,7 +267,6 @@ def resize_image_tensor(
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


-@torch.jit.unused
 def resize_image_pil(
     image: PIL.Image.Image,
     size: Union[Sequence[int], int],
@@ -319,6 +288,19 @@ def resize_image_pil(
     return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])


+@_register_kernel_internal(resize, PIL.Image.Image)
+def _resize_image_pil_dispatch(
+    image: PIL.Image.Image,
+    size: Union[Sequence[int], int],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    max_size: Optional[int] = None,
+    antialias: Optional[Union[str, bool]] = "warn",
+) -> PIL.Image.Image:
+    if antialias is False:
+        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
+    return resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)
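
Note: the public resize_image_pil keeps its v1 signature without antialias, so the registry entry is the private wrapper above, which absorbs antialias and preserves the old warning. Illustrative behavior after this change (our example; assumes an RGB PIL image):

    import PIL.Image
    import pytest
    from torchvision.transforms.v2 import functional as F

    img = PIL.Image.new("RGB", (32, 24))
    # antialias=False is accepted but ignored for PIL inputs, with a warning.
    with pytest.warns(UserWarning, match="Anti-alias option is always applied"):
        out = F.resize(img, size=[12, 16], antialias=False)
    assert out.size == (16, 12)  # PIL reports size as (width, height)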
 def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
@@ -391,14 +373,10 @@ def affine(
     fill: datapoints._FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(affine)
-
-    # TODO: consider deprecating integers from angle and shear on the future
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return affine_image_tensor(
             inpt,
-            angle,
+            angle=angle,
             translate=translate,
             scale=scale,
             shear=shear,
@@ -406,22 +384,13 @@ def affine(
             fill=fill,
             center=center,
         )
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(affine, type(inpt))
-        return kernel(
-            inpt,
-            angle,
-            translate=translate,
-            scale=scale,
-            shear=shear,
-            interpolation=interpolation,
-            fill=fill,
-            center=center,
-        )
-    elif isinstance(inpt, PIL.Image.Image):
-        return affine_image_pil(
-            inpt,
-            angle,
-            translate=translate,
-            scale=scale,
-            shear=shear,
-            interpolation=interpolation,
-            fill=fill,
-            center=center,
-        )
+
+    _log_api_usage_once(affine)
+
+    kernel = _get_kernel(affine, type(inpt))
+    return kernel(
+        inpt,
+        angle=angle,
+        translate=translate,
+        scale=scale,
+        shear=shear,
+        interpolation=interpolation,
+        fill=fill,
+        center=center,
+    )
@@ -429,11 +398,6 @@ def affine(
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
 def _affine_parse_args(
@@ -684,6 +648,7 @@ def _affine_grid(
     return output_grid.view(1, oh, ow, 2)


+@_register_kernel_internal(affine, torch.Tensor)
 @_register_kernel_internal(affine, datapoints.Image)
 def affine_image_tensor(
     image: torch.Tensor,
@@ -736,7 +701,7 @@ def affine_image_tensor(
     return output


-@torch.jit.unused
+@_register_kernel_internal(affine, PIL.Image.Image)
 def affine_image_pil(
     image: PIL.Image.Image,
     angle: Union[int, float],
@@ -983,23 +948,18 @@ def rotate(
     center: Optional[List[float]] = None,
     fill: datapoints._FillTypeJIT = None,
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(rotate)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(rotate, type(inpt))
-        return kernel(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-    elif isinstance(inpt, PIL.Image.Image):
-        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return rotate_image_tensor(
+            inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center
+        )
+
+    _log_api_usage_once(rotate)
+
+    kernel = _get_kernel(rotate, type(inpt))
+    return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+@_register_kernel_internal(rotate, torch.Tensor)
 @_register_kernel_internal(rotate, datapoints.Image)
 def rotate_image_tensor(
     image: torch.Tensor,
@@ -1045,7 +1005,7 @@ def rotate_image_tensor(
     return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
-@torch.jit.unused
+@_register_kernel_internal(rotate, PIL.Image.Image)
 def rotate_image_pil(
     image: PIL.Image.Image,
     angle: float,
@@ -1162,22 +1122,13 @@ def pad(
     fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(pad)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(pad, type(inpt))
-        return kernel(inpt, padding, fill=fill, padding_mode=padding_mode)
-    elif isinstance(inpt, PIL.Image.Image):
-        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
+
+    _log_api_usage_once(pad)
+
+    kernel = _get_kernel(pad, type(inpt))
+    return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
 def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
@@ -1204,6 +1155,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
     return [pad_left, pad_right, pad_top, pad_bottom]
+@_register_kernel_internal(pad, torch.Tensor)
 @_register_kernel_internal(pad, datapoints.Image)
 def pad_image_tensor(
     image: torch.Tensor,
@@ -1303,7 +1255,7 @@ def _pad_with_vector_fill(
     return output
-pad_image_pil = _FP.pad
+pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
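Note: since `_register_kernel_internal(...)` returns a decorator, it can be applied inline to a function that already exists, such as `_FP.pad` here; the call form and the `@` form are equivalent. A small sketch (the toy `registry`, `register`, and kernels are illustrative):

registry = {}

def register(input_type):
    def decorator(kernel):
        registry[input_type] = kernel
        return kernel
    return decorator

@register(int)                      # decorator syntax on a new definition
def kernel_for_int(x):
    return x + 1

def existing_kernel(x):             # e.g. a kernel imported from elsewhere
    return x * 2

# Equivalent to decorating `existing_kernel` at its definition site:
kernel_for_float = register(float)(existing_kernel)

assert registry[float] is existing_kernel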
 @_register_kernel_internal(pad, datapoints.Mask)
@@ -1385,23 +1337,16 @@ def pad_video(
 def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(crop)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return crop_image_tensor(inpt, top, left, height, width)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(crop, type(inpt))
-        return kernel(inpt, top, left, height, width)
-    elif isinstance(inpt, PIL.Image.Image):
-        return crop_image_pil(inpt, top, left, height, width)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)
+
+    _log_api_usage_once(crop)
+
+    kernel = _get_kernel(crop, type(inpt))
+    return kernel(inpt, top=top, left=left, height=height, width=width)
+@_register_kernel_internal(crop, torch.Tensor)
 @_register_kernel_internal(crop, datapoints.Image)
 def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     h, w = image.shape[-2:]
@@ -1422,6 +1367,7 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid
 crop_image_pil = _FP.crop
+_register_kernel_internal(crop, PIL.Image.Image)(crop_image_pil)
 def crop_bounding_boxes(
@@ -1484,23 +1430,26 @@ def perspective(
     fill: datapoints._FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(perspective)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return perspective_image_tensor(
-            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
-        )
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(perspective, type(inpt))
-        return kernel(inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients)
-    elif isinstance(inpt, PIL.Image.Image):
-        return perspective_image_pil(
-            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
-        )
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return perspective_image_tensor(
+            inpt,
+            startpoints=startpoints,
+            endpoints=endpoints,
+            interpolation=interpolation,
+            fill=fill,
+            coefficients=coefficients,
+        )
+
+    _log_api_usage_once(perspective)
+
+    kernel = _get_kernel(perspective, type(inpt))
+    return kernel(
+        inpt,
+        startpoints=startpoints,
+        endpoints=endpoints,
+        interpolation=interpolation,
+        fill=fill,
+        coefficients=coefficients,
+    )
@@ -1551,6 +1500,7 @@ def _perspective_coefficients(
     raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")
+@_register_kernel_internal(perspective, torch.Tensor)
 @_register_kernel_internal(perspective, datapoints.Image)
 def perspective_image_tensor(
     image: torch.Tensor,
@@ -1598,7 +1548,7 @@ def perspective_image_tensor(
     return output
-@torch.jit.unused
+@_register_kernel_internal(perspective, PIL.Image.Image)
 def perspective_image_pil(
     image: PIL.Image.Image,
     startpoints: Optional[List[List[int]]],
@@ -1787,29 +1737,19 @@ def elastic(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints._FillTypeJIT = None,
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(elastic)
-    if not isinstance(displacement, torch.Tensor):
-        raise TypeError("Argument displacement should be a Tensor")
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(elastic, type(inpt))
-        return kernel(inpt, displacement, interpolation=interpolation, fill=fill)
-    elif isinstance(inpt, PIL.Image.Image):
-        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
+
+    _log_api_usage_once(elastic)
+
+    kernel = _get_kernel(elastic, type(inpt))
+    return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
 elastic_transform = elastic
+@_register_kernel_internal(elastic, torch.Tensor)
 @_register_kernel_internal(elastic, datapoints.Image)
 def elastic_image_tensor(
     image: torch.Tensor,
@@ -1867,7 +1807,7 @@ def elastic_image_tensor(
     return output
-@torch.jit.unused
+@_register_kernel_internal(elastic, PIL.Image.Image)
 def elastic_image_pil(
     image: PIL.Image.Image,
     displacement: torch.Tensor,
@@ -1990,21 +1930,13 @@ def elastic_video(
 def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(center_crop)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return center_crop_image_tensor(inpt, output_size)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(center_crop, type(inpt))
-        return kernel(inpt, output_size)
-    elif isinstance(inpt, PIL.Image.Image):
-        return center_crop_image_pil(inpt, output_size)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return center_crop_image_tensor(inpt, output_size=output_size)
+
+    _log_api_usage_once(center_crop)
+
+    kernel = _get_kernel(center_crop, type(inpt))
+    return kernel(inpt, output_size=output_size)
 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
@@ -2034,6 +1966,7 @@ def _center_crop_compute_crop_anchor(
     return crop_top, crop_left
+@_register_kernel_internal(center_crop, torch.Tensor)
 @_register_kernel_internal(center_crop, datapoints.Image)
 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
@@ -2054,7 +1987,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
     return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
-@torch.jit.unused
+@_register_kernel_internal(center_crop, PIL.Image.Image)
 def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     image_height, image_width = get_size_image_pil(image)
@@ -2125,25 +2058,34 @@ def resized_crop(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(resized_crop)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return resized_crop_image_tensor(
-            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
-        )
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(resized_crop, type(inpt))
-        return kernel(inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
-    elif isinstance(inpt, PIL.Image.Image):
-        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return resized_crop_image_tensor(
+            inpt,
+            top=top,
+            left=left,
+            height=height,
+            width=width,
+            size=size,
+            interpolation=interpolation,
+            antialias=antialias,
+        )
+
+    _log_api_usage_once(resized_crop)
+
+    kernel = _get_kernel(resized_crop, type(inpt))
+    return kernel(
+        inpt,
+        top=top,
+        left=left,
+        height=height,
+        width=width,
+        size=size,
+        interpolation=interpolation,
+        antialias=antialias,
+    )
+@_register_kernel_internal(resized_crop, torch.Tensor)
 @_register_kernel_internal(resized_crop, datapoints.Image)
 def resized_crop_image_tensor(
     image: torch.Tensor,
@@ -2159,7 +2101,6 @@ def resized_crop_image_tensor(
     return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias)
-@torch.jit.unused
 def resized_crop_image_pil(
     image: PIL.Image.Image,
     top: int,
@@ -2173,6 +2114,30 @@ def resized_crop_image_pil(
     return resize_image_pil(image, size, interpolation=interpolation)
+@_register_kernel_internal(resized_crop, PIL.Image.Image)
+def resized_crop_image_pil_dispatch(
+    image: PIL.Image.Image,
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: List[int],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    antialias: Optional[Union[str, bool]] = "warn",
+) -> PIL.Image.Image:
+    if antialias is False:
+        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
+    return resized_crop_image_pil(
+        image,
+        top=top,
+        left=left,
+        height=height,
+        width=width,
+        size=size,
+        interpolation=interpolation,
+    )
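Note: `resized_crop_image_pil` keeps its PIL-natural signature; only the private adapter registered above grows the `antialias` parameter that the dispatcher forwards to every kernel. A short sketch of that separation (toy names, not torchvision API):

import warnings

def public_pil_kernel(image, size):
    # User-facing kernel: no dispatcher-only parameters in its signature.
    return image

def pil_kernel_adapter(image, size, antialias="warn"):
    # Registered adapter: absorbs kwargs that only make sense for tensor
    # inputs, then delegates to the public kernel unchanged.
    if antialias is False:
        warnings.warn("antialias is ignored for PIL images")
    return public_pil_kernel(image, size)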
 def resized_crop_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
@@ -2244,21 +2209,13 @@ def five_crop(
     datapoints._InputTypeJIT,
     datapoints._InputTypeJIT,
 ]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(five_crop)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return five_crop_image_tensor(inpt, size)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(five_crop, type(inpt))
-        return kernel(inpt, size)
-    elif isinstance(inpt, PIL.Image.Image):
-        return five_crop_image_pil(inpt, size)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return five_crop_image_tensor(inpt, size=size)
+
+    _log_api_usage_once(five_crop)
+
+    kernel = _get_kernel(five_crop, type(inpt))
+    return kernel(inpt, size=size)
 def _parse_five_crop_size(size: List[int]) -> List[int]:
@@ -2275,6 +2232,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
     return size
+@_register_five_ten_crop_kernel(five_crop, torch.Tensor)
 @_register_five_ten_crop_kernel(five_crop, datapoints.Image)
 def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
@@ -2294,7 +2252,7 @@ def five_crop_image_tensor(
     return tl, tr, bl, br, center
-@torch.jit.unused
+@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image)
 def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
@@ -2335,23 +2293,16 @@ def ten_crop(
     datapoints._InputTypeJIT,
     datapoints._InputTypeJIT,
 ]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(ten_crop)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(ten_crop, type(inpt))
-        return kernel(inpt, size, vertical_flip=vertical_flip)
-    elif isinstance(inpt, PIL.Image.Image):
-        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip)
+
+    _log_api_usage_once(ten_crop)
+
+    kernel = _get_kernel(ten_crop, type(inpt))
+    return kernel(inpt, size=size, vertical_flip=vertical_flip)
+@_register_five_ten_crop_kernel(ten_crop, torch.Tensor)
 @_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
 def ten_crop_image_tensor(
     image: torch.Tensor, size: List[int], vertical_flip: bool = False
@@ -2379,7 +2330,7 @@ def ten_crop_image_tensor(
     return non_flipped + flipped
-@torch.jit.unused
+@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image)
 def ten_crop_image_pil(
     image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
...
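Note: with `torch.Tensor` and `PIL.Image.Image` registered the same way as datapoints, one registry drives all three input families and the old `isinstance` ladders disappear. A self-contained toy showing that shape (all classes and the `blur` dispatcher below are illustrative stand-ins):

class Tensorish: pass                  # stand-in for torch.Tensor
class PILish: pass                     # stand-in for PIL.Image.Image
class Datapointish(Tensorish): pass    # stand-in for datapoints.Image

REGISTRY = {}

def register(input_type):
    def decorator(kernel):
        REGISTRY[input_type] = kernel
        return kernel
    return decorator

@register(Tensorish)
def blur_tensor(x): return "tensor kernel"

@register(PILish)
def blur_pil(x): return "pil kernel"

@register(Datapointish)
def blur_datapoint(x): return "datapoint kernel"

def blur(x):
    return REGISTRY[type(x)](x)  # one lookup replaces the isinstance chain

assert blur(Tensorish()) == "tensor kernel"
assert blur(PILish()) == "pil kernel"
assert blur(Datapointish()) == "datapoint kernel"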
@@ -13,23 +13,16 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupporte
 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_dimensions)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_dimensions_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(get_dimensions, type(inpt))
-        return kernel(inpt)
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_dimensions_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return get_dimensions_image_tensor(inpt)
+
+    _log_api_usage_once(get_dimensions)
+
+    kernel = _get_kernel(get_dimensions, type(inpt))
+    return kernel(inpt)
+@_register_kernel_internal(get_dimensions, torch.Tensor)
 @_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
 def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])
@@ -43,7 +36,7 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
     raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
-get_dimensions_image_pil = _FP.get_dimensions
+get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
 @_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
@@ -53,23 +46,16 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_num_channels)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_num_channels_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(get_num_channels, type(inpt))
-        return kernel(inpt)
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_num_channels_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return get_num_channels_image_tensor(inpt)
+
+    _log_api_usage_once(get_num_channels)
+
+    kernel = _get_kernel(get_num_channels, type(inpt))
+    return kernel(inpt)
+@_register_kernel_internal(get_num_channels, torch.Tensor)
 @_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
 def get_num_channels_image_tensor(image: torch.Tensor) -> int:
     chw = image.shape[-3:]
@@ -82,7 +68,7 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
     raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
-get_num_channels_image_pil = _FP.get_image_num_channels
+get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
 @_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
@@ -96,23 +82,16 @@ get_image_num_channels = get_num_channels
 def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_size)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_size_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(get_size, type(inpt))
-        return kernel(inpt)
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_size_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return get_size_image_tensor(inpt)
+
+    _log_api_usage_once(get_size)
+
+    kernel = _get_kernel(get_size, type(inpt))
+    return kernel(inpt)
+@_register_kernel_internal(get_size, torch.Tensor)
 @_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
 def get_size_image_tensor(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
@@ -123,7 +102,7 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
     raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
-@torch.jit.unused
+@_register_kernel_internal(get_size, PIL.Image.Image)
 def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     width, height = _FP.get_image_size(image)
     return [height, width]
@@ -146,21 +125,16 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
 @_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_num_frames)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_num_frames_video(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(get_num_frames, type(inpt))
-        return kernel(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return get_num_frames_video(inpt)
+
+    _log_api_usage_once(get_num_frames)
+
+    kernel = _get_kernel(get_num_frames, type(inpt))
+    return kernel(inpt)
+@_register_kernel_internal(get_num_frames, torch.Tensor)
 @_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
 def get_num_frames_video(video: torch.Tensor) -> int:
     return video.shape[-4]
...
@@ -11,13 +11,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
-from ._utils import (
-    _get_kernel,
-    _register_explicit_noop,
-    _register_kernel_internal,
-    _register_unsupported_type,
-    is_simple_tensor,
-)
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
@@ -28,19 +22,16 @@ def normalize(
     std: List[float],
     inplace: bool = False,
 ) -> torch.Tensor:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(normalize)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(normalize, type(inpt))
-        return kernel(inpt, mean=mean, std=std, inplace=inplace)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+
+    _log_api_usage_once(normalize)
+
+    kernel = _get_kernel(normalize, type(inpt))
+    return kernel(inpt, mean=mean, std=std, inplace=inplace)
+@_register_kernel_internal(normalize, torch.Tensor)
 @_register_kernel_internal(normalize, datapoints.Image)
 def normalize_image_tensor(
     image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
@@ -86,21 +77,13 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
 def gaussian_blur(
     inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(gaussian_blur)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(gaussian_blur, type(inpt))
-        return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, PIL.Image.Image):
-        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
+
+    _log_api_usage_once(gaussian_blur)
+
+    kernel = _get_kernel(gaussian_blur, type(inpt))
+    return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
 def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
@@ -119,6 +102,7 @@ def _get_gaussian_kernel2d(
     return kernel2d
+@_register_kernel_internal(gaussian_blur, torch.Tensor)
 @_register_kernel_internal(gaussian_blur, datapoints.Image)
 def gaussian_blur_image_tensor(
     image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
@@ -184,7 +168,7 @@ def gaussian_blur_image_tensor(
     return output
-@torch.jit.unused
+@_register_kernel_internal(gaussian_blur, PIL.Image.Image)
 def gaussian_blur_image_pil(
     image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> PIL.Image.Image:
@@ -200,21 +184,17 @@ def gaussian_blur_video(
     return gaussian_blur_image_tensor(video, kernel_size, sigma)
+@_register_unsupported_type(PIL.Image.Image)
 def to_dtype(
     inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(to_dtype)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return to_dtype_image_tensor(inpt, dtype, scale=scale)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(to_dtype, type(inpt))
-        return kernel(inpt, dtype, scale=scale)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
+
+    _log_api_usage_once(to_dtype)
+
+    kernel = _get_kernel(to_dtype, type(inpt))
+    return kernel(inpt, dtype=dtype, scale=scale)
 def _num_value_bits(dtype: torch.dtype) -> int:
@@ -232,6 +212,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
     raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
+@_register_kernel_internal(to_dtype, torch.Tensor)
 @_register_kernel_internal(to_dtype, datapoints.Image)
 def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
...
@@ -5,27 +5,23 @@ from torchvision import datapoints
 from torchvision.utils import _log_api_usage_once
-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 @_register_explicit_noop(
     PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
 )
 def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(uniform_temporal_subsample)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return uniform_temporal_subsample_video(inpt, num_samples)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
-        return kernel(inpt, num_samples)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
+
+    _log_api_usage_once(uniform_temporal_subsample)
+
+    kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
+    return kernel(inpt, num_samples=num_samples)
+@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
 @_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
 def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
     # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...
@@ -23,15 +23,17 @@ def _kernel_datapoint_wrapper(kernel):
     return wrapper
-def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=True):
+def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
     registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
-    if datapoint_cls in registry:
-        raise TypeError(
-            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
-        )
+    if input_type in registry:
+        raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")
     def decorator(kernel):
-        registry[datapoint_cls] = _kernel_datapoint_wrapper(kernel) if datapoint_wrapper else kernel
+        registry[input_type] = (
+            _kernel_datapoint_wrapper(kernel)
+            if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
+            else kernel
+        )
         return kernel
     return decorator
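Note: now that plain types flow through `_register_kernel_internal`, the `issubclass(input_type, datapoints.Datapoint)` guard ensures only datapoint kernels get the output re-wrapping; tensor and PIL kernels are stored as-is. A rough sketch of wrap-on-registration, where the toy `wrap_output` stands in for `_kernel_datapoint_wrapper`:

class Datapoint:
    def __init__(self, value=None):
        self.value = value

class ImageDP(Datapoint):
    pass

REGISTRY = {}

def wrap_output(kernel):
    # Toy stand-in for _kernel_datapoint_wrapper: re-wrap the plain result
    # in the input's datapoint class so the datapoint type survives the call.
    def wrapper(inpt, *args, **kwargs):
        return type(inpt)(kernel(inpt, *args, **kwargs))
    return wrapper

def register(input_type, *, datapoint_wrapper=True):
    def decorator(kernel):
        REGISTRY[input_type] = (
            wrap_output(kernel)
            if issubclass(input_type, Datapoint) and datapoint_wrapper
            else kernel  # plain tensor / PIL kernels are stored unwrapped
        )
        return kernel
    return decorator

@register(ImageDP)
def flip(inpt):
    return inpt.value[::-1]

out = REGISTRY[ImageDP](ImageDP("abc"))
assert isinstance(out, ImageDP) and out.value == "cba"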
@@ -43,7 +45,9 @@ def _name_to_dispatcher(name):
     try:
         return getattr(torchvision.transforms.v2.functional, name)
     except AttributeError:
-        raise ValueError(f"Could not find dispatcher with name '{name}'.") from None
+        raise ValueError(
+            f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
+        ) from None
 def register_kernel(dispatcher, datapoint_cls):
@@ -54,23 +58,58 @@ def register_kernel(dispatcher, datapoint_cls):
     """
     if isinstance(dispatcher, str):
         dispatcher = _name_to_dispatcher(name=dispatcher)
+    elif not (
+        callable(dispatcher)
+        and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
+    ):
+        raise ValueError(
+            f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
+            f"but got {dispatcher}."
+        )
+
+    if not (
+        isinstance(datapoint_cls, type)
+        and issubclass(datapoint_cls, datapoints.Datapoint)
+        and datapoint_cls is not datapoints.Datapoint
+    ):
+        raise ValueError(
+            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
+            f"but got {datapoint_cls}."
+        )
+
     return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
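Note: the public `register_kernel` now validates both arguments before delegating. Hedged illustrations of calls that should pass or fail under the checks above, assuming `register_kernel` is exported from `torchvision.transforms.v2.functional` as in this PR series; `MyDatapoint` is a hypothetical user subclass:

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

class MyDatapoint(datapoints.Datapoint):  # hypothetical user-defined datapoint
    pass

@F.register_kernel(F.resize, MyDatapoint)     # OK: v2 dispatcher + proper subclass
def resize_my_datapoint(inpt, size, **kwargs):
    return inpt

# Each of these violates one of the new checks and raises ValueError:
# F.register_kernel(print, MyDatapoint)              # not a v2 functional dispatcher
# F.register_kernel(F.resize, datapoints.Datapoint)  # the base class itself is rejected
# F.register_kernel(F.resize, int)                   # not a Datapoint subclass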
-def _get_kernel(dispatcher, datapoint_cls):
+def _get_kernel(dispatcher, input_type):
     registry = _KERNEL_REGISTRY.get(dispatcher)
     if not registry:
-        raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.")
-    if datapoint_cls in registry:
-        return registry[datapoint_cls]
-    for registered_cls, kernel in registry.items():
-        if issubclass(datapoint_cls, registered_cls):
-            return kernel
-    return _noop
+        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
+
+    # In case we have an exact type match, we take a shortcut.
+    if input_type in registry:
+        return registry[input_type]
+
+    # In case of datapoints, we check if we have a kernel for a superclass registered
+    if issubclass(input_type, datapoints.Datapoint):
+        # Since we have already checked for an exact match above, we can start the traversal at the superclass.
+        for cls in input_type.__mro__[1:]:
+            if cls is datapoints.Datapoint:
+                # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly
+                # stop the MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint,
+                # since we don't allow kernels to be registered for datapoints.Datapoint anyway.
+                break
+            elif cls in registry:
+                return registry[cls]
+
+        # Note that in the future we are not going to return a noop here, but rather raise the error below
+        return _noop
+
+    raise TypeError(
+        f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
+        f"and subclasses of torchvision.datapoints.Datapoint, "
+        f"but got {input_type} instead."
+    )
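Note: the traversal above means a user subclass of, say, `datapoints.Image` inherits the `Image` kernel via its MRO, while a direct `Datapoint` subclass with no registered kernel falls through to the noop instead of silently hitting the plain-tensor kernel. A toy model of the lookup (stand-in classes, not torchvision types):

class Tensor: pass
class Datapoint(Tensor): pass
class Image(Datapoint): pass
class MyImage(Image): pass            # user subclass, no kernel of its own
class MyDatapoint(Datapoint): pass    # user subclass of the base datapoint

REGISTRY = {Tensor: "tensor kernel", Image: "image kernel"}

def get_kernel(input_type):
    if input_type in REGISTRY:                 # exact-match shortcut
        return REGISTRY[input_type]
    if issubclass(input_type, Datapoint):
        for cls in input_type.__mro__[1:]:     # start at the superclass
            if cls is Datapoint:
                break  # stop before Tensor: no silent plain-tensor fallback
            if cls in REGISTRY:
                return REGISTRY[cls]
        return "noop"
    raise TypeError(input_type)

assert get_kernel(MyImage) == "image kernel"   # found via MRO at Image
assert get_kernel(MyDatapoint) == "noop"       # stopped at Datapoint
assert get_kernel(Tensor) == "tensor kernel"   # exact match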
 # Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
 # stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.
@@ -101,7 +140,9 @@ def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
             f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
             f"This will likely change in the future."
         )
-        register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if warn_passthrough else None))
+        _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
+            functools.partial(_noop, __msg__=msg if warn_passthrough else None)
+        )
         return dispatcher
     return decorator
@@ -115,13 +156,15 @@ def _noop(inpt, *args, __msg__=None, **kwargs):
 # TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
 # to error later, this decorator can be removed, since the error will be raised by _get_kernel
-def _register_unsupported_type(*datapoints_classes):
+def _register_unsupported_type(*input_types):
     def kernel(inpt, *args, __dispatcher_name__, **kwargs):
         raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
     def decorator(dispatcher):
-        for cls in datapoints_classes:
-            register_kernel(dispatcher, cls)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__))
+        for input_type in input_types:
+            _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
+                functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
+            )
         return dispatcher
     return decorator
@@ -129,13 +172,10 @@ def _register_unsupported_type(*datapoints_classes):
 # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
 # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-# TODO: decide if we want that
-def _register_five_ten_crop_kernel(dispatcher, datapoint_cls):
+def _register_five_ten_crop_kernel(dispatcher, input_type):
     registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
-    if datapoint_cls in registry:
-        raise TypeError(
-            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
-        )
+    if input_type in registry:
+        raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
     def wrap(kernel):
         @functools.wraps(kernel)
@@ -147,7 +187,7 @@ def _register_five_ten_crop_kernel(dispatcher, datapoint_cls):
         return wrapper
     def decorator(kernel):
-        registry[datapoint_cls] = wrap(kernel)
+        registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
         return kernel
     return decorator
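Note: five_crop / ten_crop need this bespoke registration path because their kernels return a tuple of crops rather than a single output, so the standard single-output datapoint wrapper does not fit; as above, plain-type kernels are now stored unwrapped. A toy sketch of a tuple-aware wrapper; how the real `wrap` re-wraps each element is inferred from context, so treat the details as illustrative:

import functools

class Datapoint:
    def __init__(self, value=None):
        self.value = value

class ImageDP(Datapoint):
    pass

def wrap_five_ten_crop(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        outputs = kernel(inpt, *args, **kwargs)
        # Re-wrap every element of the returned tuple, not one output.
        return tuple(type(inpt)(o) for o in outputs)
    return wrapper

@wrap_five_ten_crop
def five_crop_toy(inpt, size):
    return ("tl", "tr", "bl", "br", "center")

crops = five_crop_toy(ImageDP("img"), size=(2, 2))
assert len(crops) == 5 and all(isinstance(c, ImageDP) for c in crops)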