"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "2febd8bcbd7d2e645e6934eb5ae06fab312bd8fa"
Unverified Commit 2030d208 authored by Philip Meier, committed by GitHub

register tensor and PIL kernel the same way as datapoints (#7797)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 84db2ac4
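
In short: plain torch.Tensor and PIL.Image.Image inputs are now resolved through the same per-dispatcher kernel registry that datapoint subclasses already use, so the per-type if/elif branches inside every dispatcher go away. A minimal sketch of the registry pattern (simplified stand-in names, not the verbatim torchvision code):

# Minimal sketch of the shared-registry idea: every dispatcher owns a
# {input_type: kernel} mapping, and plain tensors and PIL images become
# ordinary entries in it instead of special cases.
_KERNEL_REGISTRY = {}  # dispatcher -> {input_type: kernel}

def register_kernel_for(dispatcher, input_type):
    def decorator(kernel):
        _KERNEL_REGISTRY.setdefault(dispatcher, {})[input_type] = kernel
        return kernel
    return decorator

def dispatch(dispatcher, inpt, *args, **kwargs):
    kernel = _KERNEL_REGISTRY[dispatcher][type(inpt)]  # exact-match lookup only
    return kernel(inpt, *args, **kwargs)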
@@ -2,7 +2,6 @@ import inspect
 import math
 import os
 import re
-from unittest import mock
 
 import numpy as np
 import PIL.Image
@@ -25,7 +24,6 @@ from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
-from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
 from torchvision.transforms.v2.utils import is_simple_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -359,18 +357,6 @@ class TestDispatchers:
     def test_scriptable(self, dispatcher):
         script(dispatcher)
 
-    @image_sample_inputs
-    def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
-        (image_datapoint, *other_args), kwargs = args_kwargs.load()
-        image_simple_tensor = torch.Tensor(image_datapoint)
-
-        kernel_info = info.kernel_infos[datapoints.Image]
-        spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
-
-        info.dispatcher(image_simple_tensor, *other_args, **kwargs)
-
-        spy.assert_called_once()
-
     @image_sample_inputs
     def test_simple_tensor_output_type(self, info, args_kwargs):
         (image_datapoint, *other_args), kwargs = args_kwargs.load()
@@ -381,25 +367,6 @@ class TestDispatchers:
         # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
         assert type(output) is torch.Tensor
 
-    @make_info_args_kwargs_parametrization(
-        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
-        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
-    )
-    def test_dispatch_pil(self, info, args_kwargs, spy_on):
-        (image_datapoint, *other_args), kwargs = args_kwargs.load()
-
-        if image_datapoint.ndim > 3:
-            pytest.skip("Input is batched")
-
-        image_pil = F.to_image_pil(image_datapoint)
-
-        pil_kernel_info = info.pil_kernel_info
-        spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
-
-        info.dispatcher(image_pil, *other_args, **kwargs)
-
-        spy.assert_called_once()
-
     @make_info_args_kwargs_parametrization(
         [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
         args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
@@ -416,28 +383,6 @@ class TestDispatchers:
         assert isinstance(output, PIL.Image.Image)
 
-    @make_info_args_kwargs_parametrization(
-        DISPATCHER_INFOS,
-        args_kwargs_fn=lambda info: info.sample_inputs(),
-    )
-    def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
-        (datapoint, *other_args), kwargs = args_kwargs.load()
-
-        input_type = type(datapoint)
-        wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]
-
-        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
-        # proper kernel was wrapped
-        if hasattr(wrapped_kernel, "__wrapped__"):
-            assert wrapped_kernel.__wrapped__ is info.kernels[input_type]
-
-        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
-        with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
-            info.dispatcher(datapoint, *other_args, **kwargs)
-
-        spy.assert_called_once()
-
     @make_info_args_kwargs_parametrization(
         DISPATCHER_INFOS,
         args_kwargs_fn=lambda info: info.sample_inputs(),
@@ -449,6 +394,9 @@ class TestDispatchers:
         assert isinstance(output, type(datapoint))
 
+        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
+            assert output.format == datapoint.format
+
     @pytest.mark.parametrize(
         ("dispatcher_info", "datapoint_type", "kernel_info"),
         [
......
@@ -39,7 +39,7 @@ from torchvision import datapoints
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
+from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal
 
 
 @pytest.fixture(autouse=True)
@@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
     dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
 
 
-def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
-    """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
-    preserved in doing so. For bounding boxes also checks that the format is preserved.
-    """
-    input_type = type(input)
-
-    if isinstance(input, datapoints.Datapoint):
-        wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]
-
-        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
-        # proper kernel was wrapped
-        if hasattr(wrapped_kernel, "__wrapped__"):
-            assert wrapped_kernel.__wrapped__ is kernel
-
-        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
-        with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
-            output = dispatcher(input, *args, **kwargs)
-
-        spy.assert_called_once()
-    else:
-        with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy:
-            output = dispatcher(input, *args, **kwargs)
-
-            spy.assert_called_once()
-
-    assert isinstance(output, input_type)
-
-    if isinstance(input, datapoints.BoundingBoxes):
-        assert output.format == input.format
-
-
 def check_dispatcher(
     dispatcher,
+    # TODO: remove this parameter
     kernel,
     input,
     *args,
     check_scripted_smoke=True,
-    check_dispatch=True,
     **kwargs,
 ):
     unknown_input = object()
+    with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
+        dispatcher(unknown_input, *args, **kwargs)
+
     with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
-        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
-            dispatcher(unknown_input, *args, **kwargs)
+        output = dispatcher(input, *args, **kwargs)
 
         spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")
 
+    assert isinstance(output, type(input))
+
+    if isinstance(input, datapoints.BoundingBoxes):
+        assert output.format == input.format
+
     if check_scripted_smoke:
         _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
 
-    if check_dispatch:
-        _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)
-
 
 def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
     """Checks if the signature of the dispatcher matches the kernel signature."""
@@ -412,18 +385,20 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size
 @pytest.mark.parametrize(
-    ("dispatcher", "registered_datapoint_clss"),
+    ("dispatcher", "registered_input_types"),
     [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
 )
-def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
+def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
     missing = {
+        torch.Tensor,
+        PIL.Image.Image,
         datapoints.Image,
         datapoints.BoundingBoxes,
         datapoints.Mask,
         datapoints.Video,
-    } - registered_datapoint_clss
+    } - registered_input_types
     if missing:
-        names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
+        names = sorted(str(t) for t in missing)
         raise AssertionError(
             "\n".join(
                 [
@@ -1753,11 +1728,6 @@ class TestToDtype:
             F.to_dtype,
             kernel,
             make_input(dtype=input_dtype, device=device),
-            # TODO: we could leave check_dispatch to True but it currently fails
-            # in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
-            # We should be able to put this back if we change the dispatch
-            # mechanism e.g. via https://github.com/pytorch/vision/pull/7733
-            check_dispatch=False,
             dtype=output_dtype,
             scale=scale,
         )
@@ -2208,9 +2178,105 @@ class TestRegisterKernel:
         t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
         t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
 
-    def test_bad_disaptcher_name(self):
-        class CustomDatapoint(datapoints.Datapoint):
-            pass
-
+    def test_errors(self):
         with pytest.raises(ValueError, match="Could not find dispatcher with name"):
-            F.register_kernel("bad_name", CustomDatapoint)
+            F.register_kernel("bad_name", datapoints.Image)
+
+        with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
+            F.register_kernel(datapoints.Image, F.resize)
+
+        with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
+            F.register_kernel(F.resize, object)
+
+        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+            F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
+
+
+class TestGetKernel:
+    # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
+    # would also be fine
+    KERNELS = {
+        torch.Tensor: F.resize_image_tensor,
+        PIL.Image.Image: F.resize_image_pil,
+        datapoints.Image: F.resize_image_tensor,
+        datapoints.BoundingBoxes: F.resize_bounding_boxes,
+        datapoints.Mask: F.resize_mask,
+        datapoints.Video: F.resize_video,
+    }
+
+    def test_unsupported_types(self):
+        class MyTensor(torch.Tensor):
+            pass
+
+        class MyPILImage(PIL.Image.Image):
+            pass
+
+        for input_type in [str, int, object, MyTensor, MyPILImage]:
+            with pytest.raises(
+                TypeError,
+                match=(
+                    "supports inputs of type torch.Tensor, PIL.Image.Image, "
+                    "and subclasses of torchvision.datapoints.Datapoint"
+                ),
+            ):
+                _get_kernel(F.resize, input_type)
+
+    def test_exact_match(self):
+        # We cannot use F.resize together with self.KERNELS mapping here directly, since this is only the
+        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
+        # here, register the kernels without wrapper, and check the exact matching afterwards.
+        def resize_with_pure_kernels():
+            pass
+
+        for input_type, kernel in self.KERNELS.items():
+            _register_kernel_internal(resize_with_pure_kernels, input_type, datapoint_wrapper=False)(kernel)
+
+            assert _get_kernel(resize_with_pure_kernels, input_type) is kernel
+
+    def test_builtin_datapoint_subclass(self):
+        # We cannot use F.resize together with self.KERNELS mapping here directly, since this is only the
+        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
+        # here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
+        # to the kernel of the corresponding superclass
+        def resize_with_pure_kernels():
+            pass
+
+        class MyImage(datapoints.Image):
+            pass
+
+        class MyBoundingBoxes(datapoints.BoundingBoxes):
+            pass
+
+        class MyMask(datapoints.Mask):
+            pass
+
+        class MyVideo(datapoints.Video):
+            pass
+
+        for custom_datapoint_subclass in [
+            MyImage,
+            MyBoundingBoxes,
+            MyMask,
+            MyVideo,
+        ]:
+            builtin_datapoint_class = custom_datapoint_subclass.__mro__[1]
+            builtin_datapoint_kernel = self.KERNELS[builtin_datapoint_class]
+            _register_kernel_internal(resize_with_pure_kernels, builtin_datapoint_class, datapoint_wrapper=False)(
+                builtin_datapoint_kernel
+            )
+
+            assert _get_kernel(resize_with_pure_kernels, custom_datapoint_subclass) is builtin_datapoint_kernel
+
+    def test_datapoint_subclass(self):
+        class MyDatapoint(datapoints.Datapoint):
+            pass
+
+        # Note that this will be an error in the future
+        assert _get_kernel(F.resize, MyDatapoint) is _noop
+
+        def resize_my_datapoint():
+            pass
+
+        _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)
+
+        assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
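
For context, a hedged sketch of the public registration path these tests exercise; MyDatapoint and my_resize_kernel are hypothetical names, not part of torchvision:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

class MyDatapoint(datapoints.Datapoint):
    pass

@F.register_kernel(F.resize, MyDatapoint)
def my_resize_kernel(inpt, size, **kwargs):
    # register_kernel does not wrap outputs, so the kernel decides what to
    # return; here we simply delegate to the built-in tensor kernel.
    return F.resize_image_tensor(inpt.as_subclass(torch.Tensor), size, **kwargs)

my_dp = torch.rand(3, 10, 10).as_subclass(MyDatapoint)
out = F.resize(my_dp, size=[224, 224])  # resolved to my_resize_kernel via the registry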
@@ -7,7 +7,7 @@ from torchvision import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
 
-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 
 
 @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
@@ -20,23 +20,16 @@ def erase(
     v: torch.Tensor,
     inplace: bool = False,
 ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(erase)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(erase, type(inpt))
-        return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    elif isinstance(inpt, PIL.Image.Image):
-        return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    _log_api_usage_once(erase)
+
+    kernel = _get_kernel(erase, type(inpt))
+    return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
 
 
+@_register_kernel_internal(erase, torch.Tensor)
 @_register_kernel_internal(erase, datapoints.Image)
 def erase_image_tensor(
     image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
@@ -48,7 +41,7 @@ def erase_image_tensor(
     return image
 
 
-@torch.jit.unused
+@_register_kernel_internal(erase, PIL.Image.Image)
 def erase_image_pil(
     image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> PIL.Image.Image:
......
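
After this change every dispatcher body collapses to the same template: a scripting fast path, the usage log, then a registry lookup. A hedged sketch of the resulting call flow for F.erase (input values are illustrative only):

import torch
import PIL.Image
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

args = dict(i=0, j=0, h=4, w=4, v=torch.zeros(3, 4, 4))

F.erase(torch.rand(3, 8, 8), **args)                    # exact match: torch.Tensor kernel
F.erase(PIL.Image.new("RGB", (8, 8)), **args)           # exact match: PIL kernel
F.erase(datapoints.Image(torch.rand(3, 8, 8)), **args)  # Image kernel; output re-wrapped by the datapoint wrapper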
@@ -13,23 +13,16 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type
 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_dimensions)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return get_dimensions_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(get_dimensions, type(inpt))
-        return kernel(inpt)
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_dimensions_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    _log_api_usage_once(get_dimensions)
+
+    kernel = _get_kernel(get_dimensions, type(inpt))
+    return kernel(inpt)
 
 
+@_register_kernel_internal(get_dimensions, torch.Tensor)
 @_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
 def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])
@@ -43,7 +36,7 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
     raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
 
 
-get_dimensions_image_pil = _FP.get_dimensions
+get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
 
 
 @_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
@@ -53,23 +46,16 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_num_channels)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return get_num_channels_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(get_num_channels, type(inpt))
-        return kernel(inpt)
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_num_channels_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    _log_api_usage_once(get_num_channels)
+
+    kernel = _get_kernel(get_num_channels, type(inpt))
+    return kernel(inpt)
 
 
+@_register_kernel_internal(get_num_channels, torch.Tensor)
 @_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
 def get_num_channels_image_tensor(image: torch.Tensor) -> int:
     chw = image.shape[-3:]
@@ -82,7 +68,7 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
     raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
 
 
-get_num_channels_image_pil = _FP.get_image_num_channels
+get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
 
 
 @_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
@@ -96,23 +82,16 @@ get_image_num_channels = get_num_channels
 def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_size)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return get_size_image_tensor(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(get_size, type(inpt))
-        return kernel(inpt)
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_size_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    _log_api_usage_once(get_size)
+
+    kernel = _get_kernel(get_size, type(inpt))
+    return kernel(inpt)
 
 
+@_register_kernel_internal(get_size, torch.Tensor)
 @_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
 def get_size_image_tensor(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
@@ -123,7 +102,7 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
     raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
 
 
-@torch.jit.unused
+@_register_kernel_internal(get_size, PIL.Image.Image)
 def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     width, height = _FP.get_image_size(image)
     return [height, width]
@@ -146,21 +125,16 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
 @_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_num_frames)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return get_num_frames_video(inpt)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(get_num_frames, type(inpt))
-        return kernel(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    _log_api_usage_once(get_num_frames)
+
+    kernel = _get_kernel(get_num_frames, type(inpt))
+    return kernel(inpt)
 
 
+@_register_kernel_internal(get_num_frames, torch.Tensor)
 @_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
 def get_num_frames_video(video: torch.Tensor) -> int:
     return video.shape[-4]
......
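
A note on the call-form registrations above (`get_dimensions_image_pil = _register_kernel_internal(...)(_FP.get_dimensions)`): since the decorator returns the kernel unchanged, the module keeps a usable alias while the function is also recorded in the registry. A generic illustration with placeholder names:

# `register` returns the function unchanged, so the alias and the registry
# entry refer to the same object.
registry = {}

def register(key):
    def decorator(fn):
        registry[key] = fn
        return fn
    return decorator

def _existing_helper(x):
    return x * 2

helper_alias = register("helper")(_existing_helper)
assert helper_alias is _existing_helper and registry["helper"] is _existing_helper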
@@ -11,13 +11,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
 
-from ._utils import (
-    _get_kernel,
-    _register_explicit_noop,
-    _register_kernel_internal,
-    _register_unsupported_type,
-    is_simple_tensor,
-)
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type
 
 
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
@@ -28,19 +22,16 @@ def normalize(
     std: List[float],
     inplace: bool = False,
 ) -> torch.Tensor:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(normalize)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(normalize, type(inpt))
-        return kernel(inpt, mean=mean, std=std, inplace=inplace)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
-        )
+
+    _log_api_usage_once(normalize)
+
+    kernel = _get_kernel(normalize, type(inpt))
+    return kernel(inpt, mean=mean, std=std, inplace=inplace)
 
 
+@_register_kernel_internal(normalize, torch.Tensor)
 @_register_kernel_internal(normalize, datapoints.Image)
 def normalize_image_tensor(
     image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
@@ -86,21 +77,13 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
 def gaussian_blur(
     inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(gaussian_blur)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting():
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(gaussian_blur, type(inpt))
-        return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, PIL.Image.Image):
-        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+    _log_api_usage_once(gaussian_blur)
+
+    kernel = _get_kernel(gaussian_blur, type(inpt))
+    return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
 
 
 def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
@@ -119,6 +102,7 @@ def _get_gaussian_kernel2d(
     return kernel2d
 
 
+@_register_kernel_internal(gaussian_blur, torch.Tensor)
 @_register_kernel_internal(gaussian_blur, datapoints.Image)
 def gaussian_blur_image_tensor(
     image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
@@ -184,7 +168,7 @@ def gaussian_blur_image_tensor(
     return output
 
 
-@torch.jit.unused
+@_register_kernel_internal(gaussian_blur, PIL.Image.Image)
 def gaussian_blur_image_pil(
     image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> PIL.Image.Image:
@@ -200,21 +184,17 @@ def gaussian_blur_video(
     return gaussian_blur_image_tensor(video, kernel_size, sigma)
 
 
+@_register_unsupported_type(PIL.Image.Image)
 def to_dtype(
     inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
 ) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(to_dtype)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return to_dtype_image_tensor(inpt, dtype, scale=scale)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(to_dtype, type(inpt))
-        return kernel(inpt, dtype, scale=scale)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
+
+    _log_api_usage_once(to_dtype)
+
+    kernel = _get_kernel(to_dtype, type(inpt))
+    return kernel(inpt, dtype=dtype, scale=scale)
 
 
 def _num_value_bits(dtype: torch.dtype) -> int:
@@ -232,6 +212,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
     raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
 
 
+@_register_kernel_internal(to_dtype, torch.Tensor)
 @_register_kernel_internal(to_dtype, datapoints.Image)
 def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
......
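
One behavioral detail worth noting: to_dtype now registers PIL.Image.Image as an explicitly unsupported type, so a PIL input raises instead of falling into an if/elif branch. A hedged sketch of the expected behavior (not an excerpt from the test suite):

import torch
import PIL.Image
from torchvision.transforms.v2 import functional as F

out = F.to_dtype(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8), dtype=torch.float32, scale=True)
assert out.dtype == torch.float32 and out.max() <= 1.0

try:
    F.to_dtype(PIL.Image.new("RGB", (8, 8)), dtype=torch.float32)
except TypeError as e:
    print(e)  # "F.to_dtype does not support inputs of type <class 'PIL.Image.Image'>."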
@@ -5,27 +5,23 @@ from torchvision import datapoints
 from torchvision.utils import _log_api_usage_once
 
-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
 
 
 @_register_explicit_noop(
     PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
 )
 def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(uniform_temporal_subsample)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return uniform_temporal_subsample_video(inpt, num_samples)
-    elif isinstance(inpt, datapoints.Datapoint):
-        kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
-        return kernel(inpt, num_samples)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
-        )
+    if torch.jit.is_scripting():
+        return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
+
+    _log_api_usage_once(uniform_temporal_subsample)
+
+    kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
+    return kernel(inpt, num_samples=num_samples)
 
 
+@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
 @_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
 def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
     # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
......
@@ -23,15 +23,17 @@ def _kernel_datapoint_wrapper(kernel):
     return wrapper
 
 
-def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=True):
+def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
     registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
-    if datapoint_cls in registry:
-        raise TypeError(
-            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
-        )
+    if input_type in registry:
+        raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")
 
     def decorator(kernel):
-        registry[datapoint_cls] = _kernel_datapoint_wrapper(kernel) if datapoint_wrapper else kernel
+        registry[input_type] = (
+            _kernel_datapoint_wrapper(kernel)
+            if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
+            else kernel
+        )
         return kernel
 
     return decorator
@@ -43,7 +45,9 @@ def _name_to_dispatcher(name):
     try:
         return getattr(torchvision.transforms.v2.functional, name)
     except AttributeError:
-        raise ValueError(f"Could not find dispatcher with name '{name}'.") from None
+        raise ValueError(
+            f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
+        ) from None
 
 
 def register_kernel(dispatcher, datapoint_cls):
@@ -54,22 +58,57 @@ def register_kernel(dispatcher, datapoint_cls):
     """
     if isinstance(dispatcher, str):
         dispatcher = _name_to_dispatcher(name=dispatcher)
+    elif not (
+        callable(dispatcher)
+        and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
+    ):
+        raise ValueError(
+            f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
+            f"but got {dispatcher}."
+        )
+
+    if not (
+        isinstance(datapoint_cls, type)
+        and issubclass(datapoint_cls, datapoints.Datapoint)
+        and datapoint_cls is not datapoints.Datapoint
+    ):
+        raise ValueError(
+            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
+            f"but got {datapoint_cls}."
+        )
+
     return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
 
 
-def _get_kernel(dispatcher, datapoint_cls):
+def _get_kernel(dispatcher, input_type):
     registry = _KERNEL_REGISTRY.get(dispatcher)
     if not registry:
-        raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.")
+        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
 
-    if datapoint_cls in registry:
-        return registry[datapoint_cls]
+    # In case we have an exact type match, we take a shortcut.
+    if input_type in registry:
+        return registry[input_type]
 
-    for registered_cls, kernel in registry.items():
-        if issubclass(datapoint_cls, registered_cls):
-            return kernel
-
-    return _noop
+    # In case of datapoints, we check if we have a kernel for a superclass registered
+    if issubclass(input_type, datapoints.Datapoint):
+        # Since we have already checked for an exact match above, we can start the traversal at the superclass.
+        for cls in input_type.__mro__[1:]:
+            if cls is datapoints.Datapoint:
+                # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly stop
+                # the MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we
+                # don't allow kernels to be registered for datapoints.Datapoint anyway.
+                break
+            elif cls in registry:
+                return registry[cls]
 
+        # Note that in the future we are not going to return a noop here, but rather raise the error below
+        return _noop
+
+    raise TypeError(
+        f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
+        f"and subclasses of torchvision.datapoints.Datapoint, "
+        f"but got {input_type} instead."
+    )
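
To make the resolution order concrete, here is a self-contained sketch of the same MRO walk with generic names (not the torchvision internals): exact match first, then the superclass chain, stopping before the common base so custom datapoints never fall through to the plain-tensor kernel.

class Datapoint: ...
class Image(Datapoint): ...
class MyImage(Image): ...

registry = {Image: "image_kernel"}

def get_kernel(input_type):
    if input_type in registry:          # exact-match shortcut
        return registry[input_type]
    for cls in input_type.__mro__[1:]:  # start the walk at the superclass
        if cls is Datapoint:            # stop before the common base
            break
        if cls in registry:
            return registry[cls]
    return None  # torchvision currently returns a noop kernel here instead

assert get_kernel(MyImage) == "image_kernel"
assert get_kernel(Datapoint) is None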
# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
@@ -101,7 +140,9 @@ def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
                 f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
                 f"This will likely change in the future."
             )
-            register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if warn_passthrough else None))
+            _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
+                functools.partial(_noop, __msg__=msg if warn_passthrough else None)
+            )
         return dispatcher
 
     return decorator
@@ -115,13 +156,15 @@ def _noop(inpt, *args, __msg__=None, **kwargs):
 
 # TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
 # to error later, this decorator can be removed, since the error will be raised by _get_kernel
-def _register_unsupported_type(*datapoints_classes):
+def _register_unsupported_type(*input_types):
     def kernel(inpt, *args, __dispatcher_name__, **kwargs):
         raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
 
     def decorator(dispatcher):
-        for cls in datapoints_classes:
-            register_kernel(dispatcher, cls)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__))
+        for input_type in input_types:
+            _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
+                functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
+            )
         return dispatcher
 
     return decorator
@@ -129,13 +172,10 @@ def _register_unsupported_type(*input_types):
 
 # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
 # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-# TODO: decide if we want that
-def _register_five_ten_crop_kernel(dispatcher, datapoint_cls):
+def _register_five_ten_crop_kernel(dispatcher, input_type):
     registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
-    if datapoint_cls in registry:
-        raise TypeError(
-            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
-        )
+    if input_type in registry:
+        raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
 
     def wrap(kernel):
         @functools.wraps(kernel)
@@ -147,7 +187,7 @@ def _register_five_ten_crop_kernel(dispatcher, input_type):
         return wrapper
 
     def decorator(kernel):
-        registry[datapoint_cls] = wrap(kernel)
+        registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
         return kernel
 
     return decorator