Unverified commit a893f313, authored by Philip Meier, committed by GitHub

refactor Datapoint dispatch mechanism (#7747)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 16d62e30
@@ -8,9 +8,29 @@ from torchvision.transforms import _functional_pil as _FP
 from torchvision.utils import _log_api_usage_once
 
-from ._utils import is_simple_tensor
+from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor
+
+
+@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
+def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_dimensions)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return get_dimensions_image_tensor(inpt)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(get_dimensions, type(inpt))
+        return kernel(inpt)
+    elif isinstance(inpt, PIL.Image.Image):
+        return get_dimensions_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
 def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])
     ndims = len(chw)
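For orientation, the new dispatch path replaces the old hard-coded isinstance chain with a registry lookup. A minimal usage sketch (assuming a nightly build where torchvision.datapoints and the v2 functional API are available):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    image = datapoints.Image(torch.rand(3, 32, 32))
    print(F.get_dimensions(image))                  # [3, 32, 32], resolved via the kernel registry
    print(F.get_dimensions(torch.rand(3, 32, 32)))  # plain tensors take the fast path and skip the registry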
@@ -26,31 +46,31 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
 get_dimensions_image_pil = _FP.get_dimensions
 
 
+@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
 def get_dimensions_video(video: torch.Tensor) -> List[int]:
     return get_dimensions_image_tensor(video)
 
 
-def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
+@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
+def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(get_dimensions)
+        _log_api_usage_once(get_num_channels)
 
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_dimensions_image_tensor(inpt)
-
-    for typ, get_size_fn in {
-        datapoints.Image: get_dimensions_image_tensor,
-        datapoints.Video: get_dimensions_video,
-        PIL.Image.Image: get_dimensions_image_pil,
-    }.items():
-        if isinstance(inpt, typ):
-            return get_size_fn(inpt)
-
-    raise TypeError(
-        f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-        f"but got {type(inpt)} instead."
-    )
+        return get_num_channels_image_tensor(inpt)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(get_num_channels, type(inpt))
+        return kernel(inpt)
+    elif isinstance(inpt, PIL.Image.Image):
+        return get_num_channels_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
 
 
+@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
 def get_num_channels_image_tensor(image: torch.Tensor) -> int:
     chw = image.shape[-3:]
     ndims = len(chw)
@@ -65,36 +85,35 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
 get_num_channels_image_pil = _FP.get_image_num_channels
 
 
+@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
 def get_num_channels_video(video: torch.Tensor) -> int:
     return get_num_channels_image_tensor(video)
 
 
-def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_num_channels)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_num_channels_image_tensor(inpt)
-
-    for typ, get_size_fn in {
-        datapoints.Image: get_num_channels_image_tensor,
-        datapoints.Video: get_num_channels_video,
-        PIL.Image.Image: get_num_channels_image_pil,
-    }.items():
-        if isinstance(inpt, typ):
-            return get_size_fn(inpt)
-
-    raise TypeError(
-        f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-        f"but got {type(inpt)} instead."
-    )
-
-
 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
 # deprecating the old names.
 get_image_num_channels = get_num_channels
 
 
+def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_size)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return get_size_image_tensor(inpt)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(get_size, type(inpt))
+        return kernel(inpt)
+    elif isinstance(inpt, PIL.Image.Image):
+        return get_size_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
 def get_size_image_tensor(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
     ndims = len(hw)
@@ -110,59 +129,41 @@ def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     return [height, width]
 
 
+@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
 def get_size_video(video: torch.Tensor) -> List[int]:
     return get_size_image_tensor(video)
 
 
+@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
 def get_size_mask(mask: torch.Tensor) -> List[int]:
     return get_size_image_tensor(mask)
 
 
-@torch.jit.unused
+@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
 def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
     return list(bounding_box.canvas_size)
 
 
-def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(get_size)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_size_image_tensor(inpt)
-
-    # TODO: This is just the poor mans version of a dispatcher. This will be properly addressed with
-    # https://github.com/pytorch/vision/pull/7747 when we can register the kernels above without the need to have
-    # a method on the datapoint class
-    for typ, get_size_fn in {
-        datapoints.Image: get_size_image_tensor,
-        datapoints.BoundingBoxes: get_size_bounding_boxes,
-        datapoints.Mask: get_size_mask,
-        datapoints.Video: get_size_video,
-        PIL.Image.Image: get_size_image_pil,
-    }.items():
-        if isinstance(inpt, typ):
-            return get_size_fn(inpt)
-
-    raise TypeError(
-        f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-        f"but got {type(inpt)} instead."
-    )
-
-
-def get_num_frames_video(video: torch.Tensor) -> int:
-    return video.shape[-4]
-
-
+@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
     if not torch.jit.is_scripting():
         _log_api_usage_once(get_num_frames)
 
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_num_frames_video(inpt)
-    elif isinstance(inpt, datapoints.Video):
-        return get_num_frames_video(inpt)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(get_num_frames, type(inpt))
+        return kernel(inpt)
     else:
-        raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
+def get_num_frames_video(video: torch.Tensor) -> int:
+    return video.shape[-4]
 
 
 def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
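Since get_size now has a kernel registered for every builtin datapoint, the one dispatcher also covers non-image inputs. A small sketch (assuming the datapoints API at this commit, where BoundingBoxes carries a canvas_size):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    boxes = datapoints.BoundingBoxes(
        torch.tensor([[0, 0, 10, 10]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(480, 640),
    )
    print(F.get_size(boxes))  # [480, 640], read from canvas_size rather than the tensor shape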
@@ -11,9 +11,37 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
 
-from ._utils import is_simple_tensor
+from ._utils import (
+    _get_kernel,
+    _register_explicit_noop,
+    _register_kernel_internal,
+    _register_unsupported_type,
+    is_simple_tensor,
+)
+
+
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+@_register_unsupported_type(PIL.Image.Image)
+def normalize(
+    inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
+    mean: List[float],
+    std: List[float],
+    inplace: bool = False,
+) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(normalize)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(normalize, type(inpt))
+        return kernel(inpt, mean=mean, std=std, inplace=inplace)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(normalize, datapoints.Image)
 def normalize_image_tensor(
     image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
 ) -> torch.Tensor:
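The two decorators on normalize encode different policies: BoundingBoxes and Mask inputs pass through unchanged (an explicit no-op), while PIL images are rejected. A hedged sketch of the resulting behavior:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    mask = datapoints.Mask(torch.zeros(1, 8, 8))
    out = F.normalize(mask, mean=[0.5], std=[0.5])
    print(out is mask)  # True: the registered no-op returns its input untouched
    # Passing a PIL image instead raises TypeError via _register_unsupported_type.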
@@ -49,25 +77,29 @@ def normalize_image_tensor(
     return image.div_(std)
 
 
+@_register_kernel_internal(normalize, datapoints.Video)
 def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
     return normalize_image_tensor(video, mean, std, inplace=inplace)
 
 
-def normalize(
-    inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
-    mean: List[float],
-    std: List[float],
-    inplace: bool = False,
-) -> torch.Tensor:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def gaussian_blur(
+    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
+) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(normalize)
+        _log_api_usage_once(gaussian_blur)
 
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
-    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        return inpt.normalize(mean=mean, std=std, inplace=inplace)
+        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(gaussian_blur, type(inpt))
+        return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
+    elif isinstance(inpt, PIL.Image.Image):
+        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
         )
@@ -87,6 +119,7 @@ def _get_gaussian_kernel2d(
     return kernel2d
 
 
+@_register_kernel_internal(gaussian_blur, datapoints.Image)
 def gaussian_blur_image_tensor(
     image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
@@ -160,28 +193,27 @@ def gaussian_blur_image_pil(
     return to_pil_image(output, mode=image.mode)
 
 
+@_register_kernel_internal(gaussian_blur, datapoints.Video)
 def gaussian_blur_video(
     video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
     return gaussian_blur_image_tensor(video, kernel_size, sigma)
 
 
-def gaussian_blur(
-    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
+def to_dtype(
+    inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
 ) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(gaussian_blur)
+        _log_api_usage_once(to_dtype)
 
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, PIL.Image.Image):
-        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
+        return to_dtype_image_tensor(inpt, dtype, scale=scale)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(to_dtype, type(inpt))
+        return kernel(inpt, dtype, scale=scale)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
+            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
        )
@@ -200,6 +232,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
     raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
 
 
+@_register_kernel_internal(to_dtype, datapoints.Image)
 def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if image.dtype == dtype:
@@ -257,23 +290,15 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32)
     return to_dtype_image_tensor(image, dtype=dtype, scale=True)
 
 
+@_register_kernel_internal(to_dtype, datapoints.Video)
 def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     return to_dtype_image_tensor(video, dtype, scale=scale)
 
 
-def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(to_dtype)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return to_dtype_image_tensor(inpt, dtype, scale=scale)
-    elif isinstance(inpt, datapoints.Image):
-        output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
-        return datapoints.Image.wrap_like(inpt, output)
-    elif isinstance(inpt, datapoints.Video):
-        output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
-        return datapoints.Video.wrap_like(inpt, output)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.to(dtype)
-    else:
-        raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")
+@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
+def _to_dtype_tensor_dispatch(
+    inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False
+) -> datapoints._InputTypeJIT:
+    # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
+    return inpt.to(dtype)
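The comment in _to_dtype_tensor_dispatch leans on the fact that Datapoint.to() already returns the same subclass, so no as_subclass/wrap_like round trip is needed. A quick illustration (assuming the datapoints API at this commit):

    import torch
    from torchvision import datapoints

    boxes = datapoints.BoundingBoxes(
        torch.tensor([[0, 0, 10, 10]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )
    print(type(boxes.to(torch.float64)).__name__)  # BoundingBoxes, metadata intact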
+import PIL.Image
 import torch
 from torchvision import datapoints
 from torchvision.utils import _log_api_usage_once
 
-from ._utils import is_simple_tensor
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
 
 
-def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
-    # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
-    t_max = video.shape[-4] - 1
-    indices = torch.linspace(0, t_max, num_samples, device=video.device).long()
-    return torch.index_select(video, -4, indices)
-
-
+@_register_explicit_noop(
+    PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
+)
 def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(uniform_temporal_subsample)
 
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return uniform_temporal_subsample_video(inpt, num_samples)
-    elif isinstance(inpt, datapoints.Video):
-        output = uniform_temporal_subsample_video(inpt.as_subclass(torch.Tensor), num_samples)
-        return datapoints.Video.wrap_like(inpt, output)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
+        return kernel(inpt, num_samples)
     else:
-        raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
+        raise TypeError(
+            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
+def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
+    # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
+    t_max = video.shape[-4] - 1
+    indices = torch.linspace(0, t_max, num_samples, device=video.device).long()
+    return torch.index_select(video, -4, indices)
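For reference, the index computation in uniform_temporal_subsample_video picks evenly spaced frames, always keeping the first and last one. A worked example of just the arithmetic:

    import torch

    num_frames, num_samples = 10, 5
    indices = torch.linspace(0, num_frames - 1, num_samples).long()
    print(indices)  # tensor([0, 2, 4, 6, 9]): endpoints kept, interior values truncated toward zero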
-from typing import Any
+import functools
+import warnings
+from typing import Any, Callable, Dict, Type
 
 import torch
-from torchvision.datapoints._datapoint import Datapoint
+from torchvision import datapoints
 
 
 def is_simple_tensor(inpt: Any) -> bool:
-    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)
+    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
+# {dispatcher: {input_type: type_specific_kernel}}
+_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
+
+
+def _kernel_datapoint_wrapper(kernel):
+    @functools.wraps(kernel)
+    def wrapper(inpt, *args, **kwargs):
+        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
+        return type(inpt).wrap_like(inpt, output)
+
+    return wrapper
+
+
+def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=True):
+    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
+    if datapoint_cls in registry:
+        raise TypeError(
+            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
+        )
+
+    def decorator(kernel):
+        registry[datapoint_cls] = _kernel_datapoint_wrapper(kernel) if datapoint_wrapper else kernel
+        return kernel
+
+    return decorator
+
+
+def register_kernel(dispatcher, datapoint_cls):
+    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
+
+
+def _get_kernel(dispatcher, datapoint_cls):
+    registry = _KERNEL_REGISTRY.get(dispatcher)
+    if not registry:
+        raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.")
+
+    if datapoint_cls in registry:
+        return registry[datapoint_cls]
+
+    for registered_cls, kernel in registry.items():
+        if issubclass(datapoint_cls, registered_cls):
+            return kernel
+
+    return _noop
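Taken together, this is the whole mechanism: _register_kernel_internal fills _KERNEL_REGISTRY, and _get_kernel resolves a kernel at call time, with an issubclass fallback for user-defined datapoint subclasses. A hedged sketch of the public hook, where MyDatapoint and my_gaussian_blur are hypothetical names, not part of torchvision:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F


    class MyDatapoint(datapoints.Datapoint):  # hypothetical user-defined datapoint
        pass


    @F.register_kernel(F.gaussian_blur, MyDatapoint)
    def my_gaussian_blur(inpt, kernel_size, sigma=None):
        # register_kernel applies no unwrap/wrap helper, so the kernel receives
        # (and should return) the datapoint itself.
        return inpt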
+# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
+# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.
+# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we
+# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher.
+# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside
+# register_kernel.
+
+
+def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
+    """
+    Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users
+    from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior.
+
+    For example, without explicit no-op registration the following would be valid user code:
+
+    .. code::
+        from torchvision.transforms.v2 import functional as F
+
+        @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
+        def lol(...):
+            ...
+    """
+
+    def decorator(dispatcher):
+        for cls in datapoints_classes:
+            msg = (
+                f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
+                f"This will likely change in the future."
+            )
+            register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if warn_passthrough else None))
+        return dispatcher
+
+    return decorator
+
+
+def _noop(inpt, *args, __msg__=None, **kwargs):
+    if __msg__:
+        warnings.warn(__msg__, UserWarning, stacklevel=2)
+    return inpt
+# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
+# to error later, this decorator can be removed, since the error will be raised by _get_kernel
+def _register_unsupported_type(*datapoints_classes):
+    def kernel(inpt, *args, __dispatcher_name__, **kwargs):
+        raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
+
+    def decorator(dispatcher):
+        for cls in datapoints_classes:
+            register_kernel(dispatcher, cls)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__))
+        return dispatcher
+
+    return decorator
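With this in place, a dispatcher that cannot meaningfully handle a datapoint fails loudly instead of silently passing it through. Roughly, given the registrations above (the exact class path in the message is illustrative):

    # >>> F.get_dimensions(datapoints.Mask(torch.zeros(1, 4, 4)))
    # TypeError: F.get_dimensions does not support inputs of type <class '...Mask'>.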
+# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop.
+# We could get rid of this by letting _register_kernel_internal take an arbitrary wrapper rather than
+# datapoint_wrapper: bool.
+# TODO: decide if we want that
+def _register_five_ten_crop_kernel(dispatcher, datapoint_cls):
+    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
+    if datapoint_cls in registry:
+        raise TypeError(
+            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
+        )
+
+    def wrap(kernel):
+        @functools.wraps(kernel)
+        def wrapper(inpt, *args, **kwargs):
+            output = kernel(inpt, *args, **kwargs)
+            container_type = type(output)
+            return container_type(type(inpt).wrap_like(inpt, o) for o in output)
+
+        return wrapper
+
+    def decorator(kernel):
+        registry[datapoint_cls] = wrap(kernel)
+        return kernel
+
+    return decorator
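The specialized wrapper exists because five_crop / ten_crop kernels return a container of tensors rather than a single tensor, so each element has to be re-wrapped individually. A minimal sketch of just that wrapping step, using a hypothetical two-crop kernel:

    import torch
    from torchvision import datapoints

    def two_crop(image):  # hypothetical stand-in for five_crop / ten_crop
        h = image.shape[-2] // 2
        return image[..., :h, :], image[..., h:, :]

    img = datapoints.Image(torch.rand(3, 8, 8))
    outputs = two_crop(img)
    # What the wrapper above does: rebuild the container, wrapping each element
    wrapped = type(outputs)(type(img).wrap_like(img, o) for o in outputs)
    print([type(o).__name__ for o in wrapped])  # ['Image', 'Image']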