Unverified commit d5f4cc38, authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
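
The change is a mechanical rename of the public namespace: everything that lived under torchvision.datapoints now lives under torchvision.tv_tensors, and the Datapoint base class becomes TVTensor. A minimal before/after sketch of user code (shapes and values are illustrative, not taken from this commit):

import torch

# Old spelling (before this commit):
#   from torchvision import datapoints
#   img = datapoints.Image(torch.rand(3, 224, 224))

# New spelling (after this commit):
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 224, 224))    # still a torch.Tensor subclass
flipped = tv_tensors.wrap(img.flip(-1), like=img)   # re-wrap a plain tensor result as the same type
assert isinstance(flipped, tv_tensors.Image)
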
@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
 import torch
 from torch.utils._pytree import tree_map
-from torchvision.datapoints._datapoint import Datapoint
+from torchvision.tv_tensors._tv_tensor import TVTensor
 L = TypeVar("L", bound="_LabelBase")
-class _LabelBase(Datapoint):
+class _LabelBase(TVTensor):
 categories: Optional[Sequence[str]]
 @classmethod
...
@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch.nn.functional import one_hot
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms.v2 import functional as F
 from ._transform import _RandomApplyTransform, Transform
@@ -91,10 +91,10 @@ class RandomErasing(_RandomApplyTransform):
 self._log_ratio = torch.log(torch.tensor(self.ratio))
 def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
 warnings.warn(
 f"{type(self).__name__}() is currently passing through inputs of type "
-f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
 )
 return super()._call_kernel(functional, inpt, *args, **kwargs)
@@ -158,7 +158,7 @@ class _BaseMixUpCutMix(Transform):
 flat_inputs, spec = tree_flatten(inputs)
 needs_transform_list = self._needs_transform_list(flat_inputs)
-if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask):
+if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
 raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")
 labels = self._labels_getter(inputs)
@@ -188,7 +188,7 @@ class _BaseMixUpCutMix(Transform):
 return tree_unflatten(flat_outputs, spec)
 def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
-expected_num_dims = 5 if isinstance(inpt, datapoints.Video) else 4
+expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
 if inpt.ndim != expected_num_dims:
 raise ValueError(
 f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
@@ -242,13 +242,13 @@ class MixUp(_BaseMixUpCutMix):
 if inpt is params["labels"]:
 return self._mixup_label(inpt, lam=lam)
-elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
+elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
 self._check_image_or_video(inpt, batch_size=params["batch_size"])
 output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-output = datapoints.wrap(output, like=inpt)
+if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+output = tv_tensors.wrap(output, like=inpt)
 return output
 else:
@@ -309,7 +309,7 @@ class CutMix(_BaseMixUpCutMix):
 def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 if inpt is params["labels"]:
 return self._mixup_label(inpt, lam=params["lam_adjusted"])
-elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
+elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
 self._check_image_or_video(inpt, batch_size=params["batch_size"])
 x1, y1, x2, y2 = params["box"]
@@ -317,8 +317,8 @@ class CutMix(_BaseMixUpCutMix):
 output = inpt.clone()
 output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
-if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-output = datapoints.wrap(output, like=inpt)
+if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+output = tv_tensors.wrap(output, like=inpt)
 return output
 else:
...
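
Note on the MixUp/CutMix hunks above: only the wrapper types are renamed; the transforms still expect a batched image or video plus integer labels and re-wrap their output via tv_tensors.wrap. A hedged usage sketch, assuming the v2 constructors take a num_classes argument as in the 0.16-era API:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# CutMix/MixUp require a batched input: 4 dims for images, 5 for videos.
images = tv_tensors.Image(torch.rand(8, 3, 224, 224))
labels = torch.randint(0, 10, (8,))

cutmix = v2.CutMix(num_classes=10)  # assumption: num_classes kwarg, as in torchvision 0.16
mixed_images, mixed_labels = cutmix(images, labels)

# Images keep their TVTensor type (re-wrapped with tv_tensors.wrap);
# labels come back as soft targets of shape [batch, num_classes].
assert isinstance(mixed_images, tv_tensors.Image)
assert mixed_labels.shape == (8, 10)
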
@@ -5,7 +5,7 @@ import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
@@ -15,7 +15,7 @@ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
 from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
-ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
+ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]
 class _AutoAugmentBase(Transform):
@@ -46,7 +46,7 @@ class _AutoAugmentBase(Transform):
 def _flatten_and_extract_image_or_video(
 self,
 inputs: Any,
-unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
+unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
 ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
 flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
 needs_transform_list = self._needs_transform_list(flat_inputs)
@@ -56,10 +56,10 @@ class _AutoAugmentBase(Transform):
 if needs_transform and check_type(
 inpt,
 (
-datapoints.Image,
+tv_tensors.Image,
 PIL.Image.Image,
 is_pure_tensor,
-datapoints.Video,
+tv_tensors.Video,
 ),
 ):
 image_or_videos.append((idx, inpt))
@@ -590,7 +590,7 @@ class AugMix(_AutoAugmentBase):
 augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
 orig_dims = list(image_or_video.shape)
-expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
+expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
 batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
 batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
@@ -627,8 +627,8 @@ class AugMix(_AutoAugmentBase):
 mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
 mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
-if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
-mix = datapoints.wrap(mix, like=orig_image_or_video)
+if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
+mix = tv_tensors.wrap(mix, like=orig_image_or_video)
 elif isinstance(orig_image_or_video, PIL.Image.Image):
 mix = F.to_pil_image(mix)
...
@@ -6,7 +6,7 @@ from typing import Any, Callable, cast, Dict, List, Literal, Optional, Sequence,
 import PIL.Image
 import torch
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
@@ -36,8 +36,8 @@ class RandomHorizontalFlip(_RandomApplyTransform):
 .. v2betastatus:: RandomHorizontalFlip transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -56,8 +56,8 @@ class RandomVerticalFlip(_RandomApplyTransform):
 .. v2betastatus:: RandomVerticalFlip transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -76,8 +76,8 @@ class Resize(Transform):
 .. v2betastatus:: Resize transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -171,8 +171,8 @@ class CenterCrop(Transform):
 .. v2betastatus:: CenterCrop transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -199,8 +199,8 @@ class RandomResizedCrop(Transform):
 .. v2betastatus:: RandomResizedCrop transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -322,8 +322,8 @@ class FiveCrop(Transform):
 .. v2betastatus:: FiveCrop transform
-If the input is a :class:`torch.Tensor` or a :class:`~torchvision.datapoints.Image` or a
-:class:`~torchvision.datapoints.Video` it can have arbitrary number of leading batch dimensions.
+If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a
+:class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions.
 For example, the image can have ``[..., C, H, W]`` shape.
 .. Note::
@@ -338,15 +338,15 @@ class FiveCrop(Transform):
 Example:
 >>> class BatchMultiCrop(transforms.Transform):
-... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], int]):
+... def forward(self, sample: Tuple[Tuple[Union[tv_tensors.Image, tv_tensors.Video], ...], int]):
 ... images_or_videos, labels = sample
 ... batch_size = len(images_or_videos)
 ... image_or_video = images_or_videos[0]
-... images_or_videos = datapoints.wrap(torch.stack(images_or_videos), like=image_or_video)
+... images_or_videos = tv_tensors.wrap(torch.stack(images_or_videos), like=image_or_video)
 ... labels = torch.full((batch_size,), label, device=images_or_videos.device)
 ... return images_or_videos, labels
 ...
->>> image = datapoints.Image(torch.rand(3, 256, 256))
+>>> image = tv_tensors.Image(torch.rand(3, 256, 256))
 >>> label = 3
 >>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()])
 >>> images, labels = transform(image, label)
@@ -363,10 +363,10 @@ class FiveCrop(Transform):
 self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
 warnings.warn(
 f"{type(self).__name__}() is currently passing through inputs of type "
-f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
 )
 return super()._call_kernel(functional, inpt, *args, **kwargs)
@@ -374,7 +374,7 @@ class FiveCrop(Transform):
 return self._call_kernel(F.five_crop, inpt, self.size)
 def _check_inputs(self, flat_inputs: List[Any]) -> None:
-if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
+if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
 raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
@@ -384,8 +384,8 @@ class TenCrop(Transform):
 .. v2betastatus:: TenCrop transform
-If the input is a :class:`torch.Tensor` or a :class:`~torchvision.datapoints.Image` or a
-:class:`~torchvision.datapoints.Video` it can have arbitrary number of leading batch dimensions.
+If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a
+:class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions.
 For example, the image can have ``[..., C, H, W]`` shape.
 See :class:`~torchvision.transforms.v2.FiveCrop` for an example.
@@ -410,15 +410,15 @@ class TenCrop(Transform):
 self.vertical_flip = vertical_flip
 def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
 warnings.warn(
 f"{type(self).__name__}() is currently passing through inputs of type "
-f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
 )
 return super()._call_kernel(functional, inpt, *args, **kwargs)
 def _check_inputs(self, flat_inputs: List[Any]) -> None:
-if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
+if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
 raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
 def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -430,8 +430,8 @@ class Pad(Transform):
 .. v2betastatus:: Pad transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -447,7 +447,7 @@ class Pad(Transform):
 fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
 Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
 Fill value can be also a dictionary mapping data type to the fill value, e.g.
-``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
 ``Mask`` will be filled with 0.
 padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
 Default is "constant".
@@ -515,8 +515,8 @@ class RandomZoomOut(_RandomApplyTransform):
 output_width = input_width * r
 output_height = input_height * r
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -524,7 +524,7 @@ class RandomZoomOut(_RandomApplyTransform):
 fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
 Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
 Fill value can be also a dictionary mapping data type to the fill value, e.g.
-``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
 ``Mask`` will be filled with 0.
 side_range (sequence of floats, optional): tuple of two floats defines minimum and maximum factors to
 scale the input size.
@@ -574,8 +574,8 @@ class RandomRotation(Transform):
 .. v2betastatus:: RandomRotation transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -596,7 +596,7 @@ class RandomRotation(Transform):
 fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
 Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
 Fill value can be also a dictionary mapping data type to the fill value, e.g.
-``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
 ``Mask`` will be filled with 0.
 .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
@@ -648,8 +648,8 @@ class RandomAffine(Transform):
 .. v2betastatus:: RandomAffine transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -676,7 +676,7 @@ class RandomAffine(Transform):
 fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
 Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
 Fill value can be also a dictionary mapping data type to the fill value, e.g.
-``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
 ``Mask`` will be filled with 0.
 center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
 Default is the center of the image.
@@ -770,8 +770,8 @@ class RandomCrop(Transform):
 .. v2betastatus:: RandomCrop transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -794,7 +794,7 @@ class RandomCrop(Transform):
 fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
 Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
 Fill value can be also a dictionary mapping data type to the fill value, e.g.
-``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
 ``Mask`` will be filled with 0.
 padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
 Default is constant.
@@ -927,8 +927,8 @@ class RandomPerspective(_RandomApplyTransform):
 .. v2betastatus:: RandomPerspective transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -943,7 +943,7 @@ class RandomPerspective(_RandomApplyTransform):
 fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
 Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
 Fill value can be also a dictionary mapping data type to the fill value, e.g.
-``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
 ``Mask`` will be filled with 0.
 """
@@ -1014,8 +1014,8 @@ class ElasticTransform(Transform):
 .. v2betastatus:: RandomPerspective transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -1046,7 +1046,7 @@ class ElasticTransform(Transform):
 fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
 Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
 Fill value can be also a dictionary mapping data type to the fill value, e.g.
-``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
 ``Mask`` will be filled with 0.
 """
@@ -1107,15 +1107,15 @@ class RandomIoUCrop(Transform):
 .. v2betastatus:: RandomIoUCrop transform
-This transformation requires an image or video data and ``datapoints.BoundingBoxes`` in the input.
+This transformation requires an image or video data and ``tv_tensors.BoundingBoxes`` in the input.
 .. warning::
 In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
 must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
 after or later in the transforms pipeline.
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -1152,8 +1152,8 @@ class RandomIoUCrop(Transform):
 def _check_inputs(self, flat_inputs: List[Any]) -> None:
 if not (
-has_all(flat_inputs, datapoints.BoundingBoxes)
-and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_pure_tensor)
+has_all(flat_inputs, tv_tensors.BoundingBoxes)
+and has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor)
 ):
 raise TypeError(
 f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
@@ -1193,7 +1193,7 @@ class RandomIoUCrop(Transform):
 xyxy_bboxes = F.convert_bounding_box_format(
 bboxes.as_subclass(torch.Tensor),
 bboxes.format,
-datapoints.BoundingBoxFormat.XYXY,
+tv_tensors.BoundingBoxFormat.XYXY,
 )
 cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
 cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
@@ -1221,7 +1221,7 @@ class RandomIoUCrop(Transform):
 F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
 )
-if isinstance(output, datapoints.BoundingBoxes):
+if isinstance(output, tv_tensors.BoundingBoxes):
 # We "mark" the invalid boxes as degenreate, and they can be
 # removed by a later call to SanitizeBoundingBoxes()
 output[~params["is_within_crop_area"]] = 0
@@ -1235,8 +1235,8 @@ class ScaleJitter(Transform):
 .. v2betastatus:: ScaleJitter transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -1303,8 +1303,8 @@ class RandomShortestSize(Transform):
 .. v2betastatus:: RandomShortestSize transform
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
@@ -1384,8 +1384,8 @@ class RandomResize(Transform):
 output_width = size
 output_height = size
-If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
 it can have arbitrary number of leading batch dimensions. For example,
 the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
...
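
Several docstrings in the hunks above describe the ``fill`` argument as an optional dict mapping an input type to a fill value. A small sketch of that pattern with the renamed types (the choice of RandomRotation and the concrete values are illustrative):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

img = tv_tensors.Image(torch.randint(0, 256, (3, 128, 128), dtype=torch.uint8))
mask = tv_tensors.Mask(torch.zeros(128, 128, dtype=torch.int64))

# Per-type fill: new image pixels get gray (127), new mask pixels get 0 (background).
rotate = v2.RandomRotation(degrees=30, expand=True, fill={tv_tensors.Image: 127, tv_tensors.Mask: 0})
out_img, out_mask = rotate(img, mask)
assert isinstance(out_img, tv_tensors.Image) and isinstance(out_mask, tv_tensors.Mask)
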
 from typing import Any, Dict, Union
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
@@ -10,20 +10,20 @@ class ConvertBoundingBoxFormat(Transform):
 .. v2betastatus:: ConvertBoundingBoxFormat transform
 Args:
-format (str or datapoints.BoundingBoxFormat): output bounding box format.
-Possible values are defined by :class:`~torchvision.datapoints.BoundingBoxFormat` and
+format (str or tv_tensors.BoundingBoxFormat): output bounding box format.
+Possible values are defined by :class:`~torchvision.tv_tensors.BoundingBoxFormat` and
 string values match the enums, e.g. "XYXY" or "XYWH" etc.
 """
-_transformed_types = (datapoints.BoundingBoxes,)
+_transformed_types = (tv_tensors.BoundingBoxes,)
-def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
+def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None:
 super().__init__()
 if isinstance(format, str):
-format = datapoints.BoundingBoxFormat[format]
+format = tv_tensors.BoundingBoxFormat[format]
 self.format = format
-def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
+def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
 return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value]
@@ -36,7 +36,7 @@ class ClampBoundingBoxes(Transform):
 """
-_transformed_types = (datapoints.BoundingBoxes,)
+_transformed_types = (tv_tensors.BoundingBoxes,)
-def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
+def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
 return F.clamp_bounding_boxes(inpt) # type: ignore[return-value]
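
For context on the ConvertBoundingBoxFormat / ClampBoundingBoxes hunks above, a sketch of how the renamed box type is constructed and converted; the ``canvas_size`` keyword is an assumption based on the tv_tensors API of this release cycle:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# One box in XYXY pixel coordinates on an assumed 480x640 canvas.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 20.0, 110.0, 220.0]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(480, 640),  # assumption: canvas_size kwarg
)

to_xywh = v2.ConvertBoundingBoxFormat("XYWH")  # string values match the enum names
boxes_xywh = to_xywh(boxes)
assert boxes_xywh.format == tv_tensors.BoundingBoxFormat.XYWH
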
@@ -6,7 +6,7 @@ import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
@@ -74,7 +74,7 @@ class LinearTransformation(Transform):
 _v1_transform_cls = _transforms.LinearTransformation
-_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
+_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
 def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
 super().__init__()
@@ -129,8 +129,8 @@ class LinearTransformation(Transform):
 output = torch.mm(flat_inpt, transformation_matrix)
 output = output.reshape(shape)
-if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-output = datapoints.wrap(output, like=inpt)
+if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+output = tv_tensors.wrap(output, like=inpt)
 return output
@@ -227,12 +227,12 @@ class ToDtype(Transform):
 ``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``.
 Args:
-dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to.
+dtype (``torch.dtype`` or dict of ``TVTensor`` -> ``torch.dtype``): The dtype to convert to.
 If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted
 to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`.
-A dict can be passed to specify per-datapoint conversions, e.g.
-``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others"
-key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion.
+A dict can be passed to specify per-tv_tensor conversions, e.g.
+``dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others":None}``. The "others"
+key can be used as a catch-all for any other tv_tensor type, and ``None`` means no conversion.
 scale (bool, optional): Whether to scale the values for images or videos. See :ref:`range_and_dtype`.
 Default: ``False``.
 """
@@ -250,12 +250,12 @@ class ToDtype(Transform):
 if (
 isinstance(dtype, dict)
 and torch.Tensor in dtype
-and any(cls in dtype for cls in [datapoints.Image, datapoints.Video])
+and any(cls in dtype for cls in [tv_tensors.Image, tv_tensors.Video])
 ):
 warnings.warn(
-"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+"Got `dtype` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
 "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
-"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
 )
 self.dtype = dtype
 self.scale = scale
@@ -264,7 +264,7 @@ class ToDtype(Transform):
 if isinstance(self.dtype, torch.dtype):
 # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
 # is a simple torch.dtype
-if not is_pure_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
+if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
 return inpt
 dtype: Optional[torch.dtype] = self.dtype
@@ -278,10 +278,10 @@ class ToDtype(Transform):
 "If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. "
 "If you're passing a dict as dtype, "
 'you can use "others" as a catch-all key '
-'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
+'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
 )
-supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
+supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
 if dtype is None:
 if self.scale and supports_scaling:
 warnings.warn(
@@ -389,10 +389,10 @@ class SanitizeBoundingBoxes(Transform):
 )
 boxes = cast(
-datapoints.BoundingBoxes,
+tv_tensors.BoundingBoxes,
 F.convert_bounding_box_format(
 boxes,
-new_format=datapoints.BoundingBoxFormat.XYXY,
+new_format=tv_tensors.BoundingBoxFormat.XYXY,
 ),
 )
 ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
@@ -415,7 +415,7 @@ class SanitizeBoundingBoxes(Transform):
 def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 is_label = inpt is not None and inpt is params["labels"]
-is_bounding_boxes_or_mask = isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask))
+is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
 if not (is_label or is_bounding_boxes_or_mask):
 return inpt
@@ -425,4 +425,4 @@ class SanitizeBoundingBoxes(Transform):
 if is_label:
 return output
-return datapoints.wrap(output, like=inpt)
+return tv_tensors.wrap(output, like=inpt)
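
The ToDtype docstring above documents the per-type ``dtype`` dict and the ``"others"`` catch-all; a short sketch of that usage (the dict values come from the docstring, the surrounding sample is illustrative):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

img = tv_tensors.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
mask = tv_tensors.Mask(torch.zeros(64, 64, dtype=torch.uint8))

# Images -> float32 (scaled to [0, 1]), masks -> int64, anything else untouched.
to_dtype = v2.ToDtype(
    dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others": None},
    scale=True,
)
img, mask = to_dtype(img, mask)
assert img.dtype == torch.float32 and mask.dtype == torch.int64
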
@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once
@@ -56,8 +56,8 @@ class Transform(nn.Module):
 def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
 # Below is a heuristic on how to deal with pure tensor inputs:
-# 1. Pure tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
-# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
+# 1. Pure tensors, i.e. tensors that are not a tv_tensor, are passed through if there is an explicit image
+# (`tv_tensors.Image` or `PIL.Image.Image`) or video (`tv_tensors.Video`) in the sample.
 # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
 # transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
 # of `tree_flatten`, which recurses depth-first through the input.
@@ -72,7 +72,7 @@ class Transform(nn.Module):
 # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
 needs_transform_list = []
-transform_pure_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
+transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
 for inpt in flat_inputs:
 needs_transform = True
...
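
The renamed comments above describe the pure-tensor heuristic: a plain torch.Tensor is only treated as an image when no explicit image or video is present in the sample, otherwise it is passed through. A sketch of what that implies (the choice of transform is illustrative):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

flip = v2.RandomHorizontalFlip(p=1.0)

image = tv_tensors.Image(torch.rand(3, 32, 32))
plain = torch.rand(3, 32, 32)

# An explicit Image is present, so the plain tensor is passed through unchanged.
out_image, out_plain = flip(image, plain)
assert torch.equal(out_plain, plain)

# With no Image/Video/PIL image in the sample, the first pure tensor is treated
# as the image and gets flipped.
out_alone = flip(plain)
assert not torch.equal(out_alone, plain)
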
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import is_pure_tensor from torchvision.transforms.v2._utils import is_pure_tensor
...@@ -27,7 +27,7 @@ class PILToTensor(Transform): ...@@ -27,7 +27,7 @@ class PILToTensor(Transform):
class ToImage(Transform): class ToImage(Transform):
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image` """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.tv_tensors.Image`
; this does not scale values. ; this does not scale values.
.. v2betastatus:: ToImage transform .. v2betastatus:: ToImage transform
...@@ -39,7 +39,7 @@ class ToImage(Transform): ...@@ -39,7 +39,7 @@ class ToImage(Transform):
def _transform( def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image: ) -> tv_tensors.Image:
return F.to_image(inpt) return F.to_image(inpt)
...@@ -66,7 +66,7 @@ class ToPILImage(Transform): ...@@ -66,7 +66,7 @@ class ToPILImage(Transform):
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
""" """
_transformed_types = (is_pure_tensor, datapoints.Image, np.ndarray) _transformed_types = (is_pure_tensor, tv_tensors.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None: def __init__(self, mode: Optional[str] = None) -> None:
super().__init__() super().__init__()
...@@ -79,14 +79,14 @@ class ToPILImage(Transform): ...@@ -79,14 +79,14 @@ class ToPILImage(Transform):
class ToPureTensor(Transform): class ToPureTensor(Transform):
"""[BETA] Convert all datapoints to pure tensors, removing associated metadata (if any). """[BETA] Convert all tv_tensors to pure tensors, removing associated metadata (if any).
.. v2betastatus:: ToPureTensor transform .. v2betastatus:: ToPureTensor transform
This doesn't scale or change the values, only the type. This doesn't scale or change the values, only the type.
""" """
_transformed_types = (datapoints.Datapoint,) _transformed_types = (tv_tensors.TVTensor,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
return inpt.as_subclass(torch.Tensor) return inpt.as_subclass(torch.Tensor)
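ToPureTensor, renamed above, is the inverse convenience: it strips the TVTensor subclass (and with it metadata such as canvas_size) while leaving the values alone. A small sketch, assuming the current tv_tensors constructors:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[0, 0, 2, 2]]), format="XYXY", canvas_size=(4, 4)
    )

    out = v2.ToPureTensor()(boxes)
    # Values are untouched; only the subclass (and its metadata) is dropped.
    assert type(out) is torch.Tensor
    assert torch.equal(out, boxes.as_subclass(torch.Tensor))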
...@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple ...@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
import PIL.Image import PIL.Image
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
...@@ -149,10 +149,10 @@ def _parse_labels_getter( ...@@ -149,10 +149,10 @@ def _parse_labels_getter(
raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.") raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes: def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
# This assumes there is only one bbox per sample as per the general convention # This assumes there is only one bbox per sample as per the general convention
try: try:
return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)) return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes))
except StopIteration: except StopIteration:
raise ValueError("No bounding boxes were found in the sample") raise ValueError("No bounding boxes were found in the sample")
...@@ -161,7 +161,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: ...@@ -161,7 +161,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = { chws = {
tuple(get_dimensions(inpt)) tuple(get_dimensions(inpt))
for inpt in flat_inputs for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)) if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
} }
if not chws: if not chws:
raise TypeError("No image or video was found in the sample") raise TypeError("No image or video was found in the sample")
...@@ -179,11 +179,11 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]: ...@@ -179,11 +179,11 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
inpt, inpt,
( (
is_pure_tensor, is_pure_tensor,
datapoints.Image, tv_tensors.Image,
PIL.Image.Image, PIL.Image.Image,
datapoints.Video, tv_tensors.Video,
datapoints.Mask, tv_tensors.Mask,
datapoints.BoundingBoxes, tv_tensors.BoundingBoxes,
), ),
) )
} }
......
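The helpers in this hunk are private, but the conventions they encode -- at most one BoundingBoxes entry per sample, and a single (H, W) size shared by every spatial entry -- are easy to mirror with public tv_tensors types. A rough sketch under that assumption:

    import torch
    from torch.utils._pytree import tree_flatten
    from torchvision import tv_tensors

    sample = {
        "image": tv_tensors.Image(torch.rand(3, 8, 6)),
        "boxes": tv_tensors.BoundingBoxes(
            torch.tensor([[1, 1, 4, 4]]), format="XYXY", canvas_size=(8, 6)
        ),
        "label": 3,
    }

    flat, _ = tree_flatten(sample)

    # Mirror of get_bounding_boxes(): the first (and only) BoundingBoxes entry.
    boxes = next(x for x in flat if isinstance(x, tv_tensors.BoundingBoxes))

    # Mirror of query_size(): every spatial entry must agree on (H, W).
    assert boxes.canvas_size == tuple(sample["image"].shape[-2:]) == (8, 6)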
import PIL.Image import PIL.Image
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
...@@ -28,7 +28,7 @@ def erase( ...@@ -28,7 +28,7 @@ def erase(
@_register_kernel_internal(erase, torch.Tensor) @_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, datapoints.Image) @_register_kernel_internal(erase, tv_tensors.Image)
def erase_image( def erase_image(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -48,7 +48,7 @@ def _erase_image_pil( ...@@ -48,7 +48,7 @@ def _erase_image_pil(
return to_pil_image(output, mode=image.mode) return to_pil_image(output, mode=image.mode)
@_register_kernel_internal(erase, datapoints.Video) @_register_kernel_internal(erase, tv_tensors.Video)
def erase_video( def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
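The @_register_kernel_internal decorators above tie each kernel to an input type; the public functional then dispatches on the type of its first argument and re-wraps the result. A hedged illustration with F.erase:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    video = tv_tensors.Video(torch.rand(2, 3, 8, 8))  # (T, C, H, W)
    patch = torch.zeros(3, 4, 4)

    # Dispatches to the kernel registered for Video (erase_video) and keeps the subclass.
    out = F.erase(video, i=2, j=2, h=4, w=4, v=patch)
    assert isinstance(out, tv_tensors.Video) and out.shape == video.shape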
...@@ -3,7 +3,7 @@ from typing import List ...@@ -3,7 +3,7 @@ from typing import List
import PIL.Image import PIL.Image
import torch import torch
from torch.nn.functional import conv2d from torch.nn.functional import conv2d
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _max_value from torchvision.transforms._functional_tensor import _max_value
...@@ -47,7 +47,7 @@ def _rgb_to_grayscale_image( ...@@ -47,7 +47,7 @@ def _rgb_to_grayscale_image(
@_register_kernel_internal(rgb_to_grayscale, torch.Tensor) @_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
@_register_kernel_internal(rgb_to_grayscale, datapoints.Image) @_register_kernel_internal(rgb_to_grayscale, tv_tensors.Image)
def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
if num_output_channels not in (1, 3): if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
...@@ -82,7 +82,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten ...@@ -82,7 +82,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten
@_register_kernel_internal(adjust_brightness, torch.Tensor) @_register_kernel_internal(adjust_brightness, torch.Tensor)
@_register_kernel_internal(adjust_brightness, datapoints.Image) @_register_kernel_internal(adjust_brightness, tv_tensors.Image)
def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0: if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
...@@ -102,7 +102,7 @@ def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: floa ...@@ -102,7 +102,7 @@ def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: floa
return _FP.adjust_brightness(image, brightness_factor=brightness_factor) return _FP.adjust_brightness(image, brightness_factor=brightness_factor)
@_register_kernel_internal(adjust_brightness, datapoints.Video) @_register_kernel_internal(adjust_brightness, tv_tensors.Video)
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
return adjust_brightness_image(video, brightness_factor=brightness_factor) return adjust_brightness_image(video, brightness_factor=brightness_factor)
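Note the pattern above: the video kernel simply calls the image kernel, because these color ops only touch the trailing (C, H, W) dimensions and broadcast over any leading time/batch dims. A quick check, offered as an illustration rather than a guarantee:

    import torch
    from torchvision.transforms.v2 import functional as F

    video = torch.rand(5, 3, 4, 4)  # (T, C, H, W)

    # Applying the op per frame matches applying it to the whole clip at once.
    per_frame = torch.stack([F.adjust_brightness(frame, 1.5) for frame in video])
    assert torch.allclose(F.adjust_brightness(video, 1.5), per_frame)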
...@@ -119,7 +119,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten ...@@ -119,7 +119,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten
@_register_kernel_internal(adjust_saturation, torch.Tensor) @_register_kernel_internal(adjust_saturation, torch.Tensor)
@_register_kernel_internal(adjust_saturation, datapoints.Image) @_register_kernel_internal(adjust_saturation, tv_tensors.Image)
def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0: if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
...@@ -141,7 +141,7 @@ def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> to ...@@ -141,7 +141,7 @@ def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> to
_adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation) _adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)
@_register_kernel_internal(adjust_saturation, datapoints.Video) @_register_kernel_internal(adjust_saturation, tv_tensors.Video)
def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor: def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
return adjust_saturation_image(video, saturation_factor=saturation_factor) return adjust_saturation_image(video, saturation_factor=saturation_factor)
...@@ -158,7 +158,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: ...@@ -158,7 +158,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
@_register_kernel_internal(adjust_contrast, torch.Tensor) @_register_kernel_internal(adjust_contrast, torch.Tensor)
@_register_kernel_internal(adjust_contrast, datapoints.Image) @_register_kernel_internal(adjust_contrast, tv_tensors.Image)
def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if contrast_factor < 0: if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
...@@ -180,7 +180,7 @@ def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch. ...@@ -180,7 +180,7 @@ def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.
_adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast) _adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)
@_register_kernel_internal(adjust_contrast, datapoints.Video) @_register_kernel_internal(adjust_contrast, tv_tensors.Video)
def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor: def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
return adjust_contrast_image(video, contrast_factor=contrast_factor) return adjust_contrast_image(video, contrast_factor=contrast_factor)
...@@ -197,7 +197,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso ...@@ -197,7 +197,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso
@_register_kernel_internal(adjust_sharpness, torch.Tensor) @_register_kernel_internal(adjust_sharpness, torch.Tensor)
@_register_kernel_internal(adjust_sharpness, datapoints.Image) @_register_kernel_internal(adjust_sharpness, tv_tensors.Image)
def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = image.shape[-3:] num_channels, height, width = image.shape[-3:]
if num_channels not in (1, 3): if num_channels not in (1, 3):
...@@ -253,7 +253,7 @@ def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torc ...@@ -253,7 +253,7 @@ def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torc
_adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness) _adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)
@_register_kernel_internal(adjust_sharpness, datapoints.Video) @_register_kernel_internal(adjust_sharpness, tv_tensors.Video)
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor: def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
return adjust_sharpness_image(video, sharpness_factor=sharpness_factor) return adjust_sharpness_image(video, sharpness_factor=sharpness_factor)
...@@ -340,7 +340,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: ...@@ -340,7 +340,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(adjust_hue, torch.Tensor) @_register_kernel_internal(adjust_hue, torch.Tensor)
@_register_kernel_internal(adjust_hue, datapoints.Image) @_register_kernel_internal(adjust_hue, tv_tensors.Image)
def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor: def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
if not (-0.5 <= hue_factor <= 0.5): if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
...@@ -371,7 +371,7 @@ def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor: ...@@ -371,7 +371,7 @@ def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
_adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue) _adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)
@_register_kernel_internal(adjust_hue, datapoints.Video) @_register_kernel_internal(adjust_hue, tv_tensors.Video)
def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image(video, hue_factor=hue_factor) return adjust_hue_image(video, hue_factor=hue_factor)
...@@ -388,7 +388,7 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten ...@@ -388,7 +388,7 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten
@_register_kernel_internal(adjust_gamma, torch.Tensor) @_register_kernel_internal(adjust_gamma, torch.Tensor)
@_register_kernel_internal(adjust_gamma, datapoints.Image) @_register_kernel_internal(adjust_gamma, tv_tensors.Image)
def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if gamma < 0: if gamma < 0:
raise ValueError("Gamma should be a non-negative real number") raise ValueError("Gamma should be a non-negative real number")
...@@ -411,7 +411,7 @@ def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> ...@@ -411,7 +411,7 @@ def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) ->
_adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma) _adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)
@_register_kernel_internal(adjust_gamma, datapoints.Video) @_register_kernel_internal(adjust_gamma, tv_tensors.Video)
def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
return adjust_gamma_image(video, gamma=gamma, gain=gain) return adjust_gamma_image(video, gamma=gamma, gain=gain)
...@@ -428,7 +428,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor: ...@@ -428,7 +428,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
@_register_kernel_internal(posterize, torch.Tensor) @_register_kernel_internal(posterize, torch.Tensor)
@_register_kernel_internal(posterize, datapoints.Image) @_register_kernel_internal(posterize, tv_tensors.Image)
def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor: def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
if image.is_floating_point(): if image.is_floating_point():
levels = 1 << bits levels = 1 << bits
...@@ -445,7 +445,7 @@ def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor: ...@@ -445,7 +445,7 @@ def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
_posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize) _posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)
@_register_kernel_internal(posterize, datapoints.Video) @_register_kernel_internal(posterize, tv_tensors.Video)
def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image(video, bits=bits) return posterize_image(video, bits=bits)
...@@ -462,7 +462,7 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor: ...@@ -462,7 +462,7 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
@_register_kernel_internal(solarize, torch.Tensor) @_register_kernel_internal(solarize, torch.Tensor)
@_register_kernel_internal(solarize, datapoints.Image) @_register_kernel_internal(solarize, tv_tensors.Image)
def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor: def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
if threshold > _max_value(image.dtype): if threshold > _max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
...@@ -473,7 +473,7 @@ def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor: ...@@ -473,7 +473,7 @@ def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
_solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize) _solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)
@_register_kernel_internal(solarize, datapoints.Video) @_register_kernel_internal(solarize, tv_tensors.Video)
def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image(video, threshold=threshold) return solarize_image(video, threshold=threshold)
...@@ -490,7 +490,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor: ...@@ -490,7 +490,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(autocontrast, torch.Tensor) @_register_kernel_internal(autocontrast, torch.Tensor)
@_register_kernel_internal(autocontrast, datapoints.Image) @_register_kernel_internal(autocontrast, tv_tensors.Image)
def autocontrast_image(image: torch.Tensor) -> torch.Tensor: def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
c = image.shape[-3] c = image.shape[-3]
if c not in [1, 3]: if c not in [1, 3]:
...@@ -523,7 +523,7 @@ def autocontrast_image(image: torch.Tensor) -> torch.Tensor: ...@@ -523,7 +523,7 @@ def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
_autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast) _autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)
@_register_kernel_internal(autocontrast, datapoints.Video) @_register_kernel_internal(autocontrast, tv_tensors.Video)
def autocontrast_video(video: torch.Tensor) -> torch.Tensor: def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image(video) return autocontrast_image(video)
...@@ -540,7 +540,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor: ...@@ -540,7 +540,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(equalize, torch.Tensor) @_register_kernel_internal(equalize, torch.Tensor)
@_register_kernel_internal(equalize, datapoints.Image) @_register_kernel_internal(equalize, tv_tensors.Image)
def equalize_image(image: torch.Tensor) -> torch.Tensor: def equalize_image(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0: if image.numel() == 0:
return image return image
...@@ -613,7 +613,7 @@ def equalize_image(image: torch.Tensor) -> torch.Tensor: ...@@ -613,7 +613,7 @@ def equalize_image(image: torch.Tensor) -> torch.Tensor:
_equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize) _equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)
@_register_kernel_internal(equalize, datapoints.Video) @_register_kernel_internal(equalize, tv_tensors.Video)
def equalize_video(video: torch.Tensor) -> torch.Tensor: def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image(video) return equalize_image(video)
...@@ -630,7 +630,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor: ...@@ -630,7 +630,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(invert, torch.Tensor) @_register_kernel_internal(invert, torch.Tensor)
@_register_kernel_internal(invert, datapoints.Image) @_register_kernel_internal(invert, tv_tensors.Image)
def invert_image(image: torch.Tensor) -> torch.Tensor: def invert_image(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point(): if image.is_floating_point():
return 1.0 - image return 1.0 - image
...@@ -644,7 +644,7 @@ def invert_image(image: torch.Tensor) -> torch.Tensor: ...@@ -644,7 +644,7 @@ def invert_image(image: torch.Tensor) -> torch.Tensor:
_invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert) _invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)
@_register_kernel_internal(invert, datapoints.Video) @_register_kernel_internal(invert, tv_tensors.Video)
def invert_video(video: torch.Tensor) -> torch.Tensor: def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image(video) return invert_image(video)
...@@ -653,7 +653,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor ...@@ -653,7 +653,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
"""Permute the channels of the input according to the given permutation. """Permute the channels of the input according to the given permutation.
This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
:class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`. :class:`torchvision.tv_tensors.Image` and :class:`torchvision.tv_tensors.Video`.
Example: Example:
>>> rgb_image = torch.rand(3, 256, 256) >>> rgb_image = torch.rand(3, 256, 256)
...@@ -681,7 +681,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor ...@@ -681,7 +681,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
@_register_kernel_internal(permute_channels, torch.Tensor) @_register_kernel_internal(permute_channels, torch.Tensor)
@_register_kernel_internal(permute_channels, datapoints.Image) @_register_kernel_internal(permute_channels, tv_tensors.Image)
def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor: def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
shape = image.shape shape = image.shape
num_channels, height, width = shape[-3:] num_channels, height, width = shape[-3:]
...@@ -704,6 +704,6 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) ...@@ -704,6 +704,6 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int])
return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation)) return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))
@_register_kernel_internal(permute_channels, datapoints.Video) @_register_kernel_internal(permute_channels, tv_tensors.Video)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
return permute_channels_image(video, permutation=permutation) return permute_channels_image(video, permutation=permutation)
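The docstring example for permute_channels is truncated in this hunk; a hedged completion of it (RGB -> BGR via permutation [2, 1, 0]):

    import torch
    from torchvision.transforms.v2 import functional as F

    rgb_image = torch.rand(3, 256, 256)

    # Output channel k takes input channel permutation[k], so [2, 1, 0] swaps R and B.
    bgr_image = F.permute_channels(rgb_image, permutation=[2, 1, 0])
    assert torch.equal(bgr_image[0], rgb_image[2]) and torch.equal(bgr_image[2], rgb_image[0])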
...@@ -7,7 +7,7 @@ import PIL.Image ...@@ -7,7 +7,7 @@ import PIL.Image
import torch import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric from torchvision.transforms._functional_tensor import _pad_symmetric
from torchvision.transforms.functional import ( from torchvision.transforms.functional import (
...@@ -51,7 +51,7 @@ def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor: ...@@ -51,7 +51,7 @@ def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(horizontal_flip, torch.Tensor) @_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, datapoints.Image) @_register_kernel_internal(horizontal_flip, tv_tensors.Image)
def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1) return image.flip(-1)
...@@ -61,37 +61,37 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: ...@@ -61,37 +61,37 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.hflip(image) return _FP.hflip(image)
@_register_kernel_internal(horizontal_flip, datapoints.Mask) @_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image(mask) return horizontal_flip_image(mask)
def horizontal_flip_bounding_boxes( def horizontal_flip_bounding_boxes(
bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int] bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor: ) -> torch.Tensor:
shape = bounding_boxes.shape shape = bounding_boxes.shape
bounding_boxes = bounding_boxes.clone().reshape(-1, 4) bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
if format == datapoints.BoundingBoxFormat.XYXY: if format == tv_tensors.BoundingBoxFormat.XYXY:
bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_() bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
elif format == datapoints.BoundingBoxFormat.XYWH: elif format == tv_tensors.BoundingBoxFormat.XYWH:
bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_() bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
else: # format == datapoints.BoundingBoxFormat.CXCYWH: else: # format == tv_tensors.BoundingBoxFormat.CXCYWH:
bounding_boxes[:, 0].sub_(canvas_size[1]).neg_() bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()
return bounding_boxes.reshape(shape) return bounding_boxes.reshape(shape)
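For XYXY boxes the kernel above maps each x coordinate to W - x and swaps x1/x2 so the box stays well-formed. A quick numeric check of that arithmetic through the public functional, as an illustration only:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    W = 10
    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[2.0, 1.0, 5.0, 4.0]]), format="XYXY", canvas_size=(6, W)
    )

    flipped = F.horizontal_flip(boxes)
    # x1' = W - x2 = 5, x2' = W - x1 = 8; y coordinates are unchanged.
    assert torch.equal(flipped, torch.tensor([[5.0, 1.0, 8.0, 4.0]]))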
@_register_kernel_internal(horizontal_flip, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes: def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
output = horizontal_flip_bounding_boxes( output = horizontal_flip_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
) )
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(horizontal_flip, datapoints.Video) @_register_kernel_internal(horizontal_flip, tv_tensors.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image(video) return horizontal_flip_image(video)
...@@ -108,7 +108,7 @@ def vertical_flip(inpt: torch.Tensor) -> torch.Tensor: ...@@ -108,7 +108,7 @@ def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(vertical_flip, torch.Tensor) @_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, datapoints.Image) @_register_kernel_internal(vertical_flip, tv_tensors.Image)
def vertical_flip_image(image: torch.Tensor) -> torch.Tensor: def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
return image.flip(-2) return image.flip(-2)
...@@ -118,37 +118,37 @@ def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image: ...@@ -118,37 +118,37 @@ def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
return _FP.vflip(image) return _FP.vflip(image)
@_register_kernel_internal(vertical_flip, datapoints.Mask) @_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image(mask) return vertical_flip_image(mask)
def vertical_flip_bounding_boxes( def vertical_flip_bounding_boxes(
bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int] bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor: ) -> torch.Tensor:
shape = bounding_boxes.shape shape = bounding_boxes.shape
bounding_boxes = bounding_boxes.clone().reshape(-1, 4) bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
if format == datapoints.BoundingBoxFormat.XYXY: if format == tv_tensors.BoundingBoxFormat.XYXY:
bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_() bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
elif format == datapoints.BoundingBoxFormat.XYWH: elif format == tv_tensors.BoundingBoxFormat.XYWH:
bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_() bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
else: # format == datapoints.BoundingBoxFormat.CXCYWH: else: # format == tv_tensors.BoundingBoxFormat.CXCYWH:
bounding_boxes[:, 1].sub_(canvas_size[0]).neg_() bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()
return bounding_boxes.reshape(shape) return bounding_boxes.reshape(shape)
@_register_kernel_internal(vertical_flip, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes: def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
output = vertical_flip_bounding_boxes( output = vertical_flip_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
) )
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(vertical_flip, datapoints.Video) @_register_kernel_internal(vertical_flip, tv_tensors.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
return vertical_flip_image(video) return vertical_flip_image(video)
...@@ -190,7 +190,7 @@ def resize( ...@@ -190,7 +190,7 @@ def resize(
@_register_kernel_internal(resize, torch.Tensor) @_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, datapoints.Image) @_register_kernel_internal(resize, tv_tensors.Image)
def resize_image( def resize_image(
image: torch.Tensor, image: torch.Tensor,
size: List[int], size: List[int],
...@@ -319,12 +319,12 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N ...@@ -319,12 +319,12 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
return output return output
@_register_kernel_internal(resize, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resize_mask_dispatch( def _resize_mask_dispatch(
inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.Mask: ) -> tv_tensors.Mask:
output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size) output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
def resize_bounding_boxes( def resize_bounding_boxes(
...@@ -345,17 +345,17 @@ def resize_bounding_boxes( ...@@ -345,17 +345,17 @@ def resize_bounding_boxes(
) )
@_register_kernel_internal(resize, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resize_bounding_boxes_dispatch( def _resize_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output, canvas_size = resize_bounding_boxes( output, canvas_size = resize_bounding_boxes(
inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
) )
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
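tv_tensors.wrap(output, like=inpt, ...) is the general pattern in these dispatchers for re-attaching (and, as with canvas_size here, updating) metadata after a kernel has produced a plain tensor. A minimal sketch of the same idea in user code, assuming the wrap helper keeps this signature:

    import torch
    from torchvision import tv_tensors

    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[0.0, 0.0, 4.0, 4.0]]), format="XYXY", canvas_size=(8, 8)
    )

    # Pretend the canvas was resized from (8, 8) to (16, 16) and we scaled the box by hand.
    scaled = boxes.as_subclass(torch.Tensor) * 2.0

    out = tv_tensors.wrap(scaled, like=boxes, canvas_size=(16, 16))
    assert isinstance(out, tv_tensors.BoundingBoxes)
    assert out.format == boxes.format and out.canvas_size == (16, 16)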
@_register_kernel_internal(resize, datapoints.Video) @_register_kernel_internal(resize, tv_tensors.Video)
def resize_video( def resize_video(
video: torch.Tensor, video: torch.Tensor,
size: List[int], size: List[int],
...@@ -651,7 +651,7 @@ def _affine_grid( ...@@ -651,7 +651,7 @@ def _affine_grid(
@_register_kernel_internal(affine, torch.Tensor) @_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, datapoints.Image) @_register_kernel_internal(affine, tv_tensors.Image)
def affine_image( def affine_image(
image: torch.Tensor, image: torch.Tensor,
angle: Union[int, float], angle: Union[int, float],
...@@ -730,7 +730,7 @@ def _affine_image_pil( ...@@ -730,7 +730,7 @@ def _affine_image_pil(
def _affine_bounding_boxes_with_expand( def _affine_bounding_boxes_with_expand(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int], canvas_size: Tuple[int, int],
angle: Union[int, float], angle: Union[int, float],
translate: List[float], translate: List[float],
...@@ -749,7 +749,7 @@ def _affine_bounding_boxes_with_expand( ...@@ -749,7 +749,7 @@ def _affine_bounding_boxes_with_expand(
device = bounding_boxes.device device = bounding_boxes.device
bounding_boxes = ( bounding_boxes = (
convert_bounding_box_format( convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
) )
).reshape(-1, 4) ).reshape(-1, 4)
...@@ -808,9 +808,9 @@ def _affine_bounding_boxes_with_expand( ...@@ -808,9 +808,9 @@ def _affine_bounding_boxes_with_expand(
new_width, new_height = _compute_affine_output_size(affine_vector, width, height) new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
canvas_size = (new_height, new_width) canvas_size = (new_height, new_width)
out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size) out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
out_bboxes = convert_bounding_box_format( out_bboxes = convert_bounding_box_format(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape) ).reshape(original_shape)
out_bboxes = out_bboxes.to(original_dtype) out_bboxes = out_bboxes.to(original_dtype)
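The box kernels in this file work internally in XYXY: convert, apply the coordinate math, clamp to the canvas, then convert back to the caller's format. The same round trip is available through the public functionals; a sketch, assuming convert_bounding_box_format and clamp_bounding_boxes keep their current signatures:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    cxcywh = tv_tensors.BoundingBoxes(
        torch.tensor([[5.0, 5.0, 4.0, 4.0]]),  # center (5, 5), width 4, height 4
        format="CXCYWH",
        canvas_size=(8, 8),
    )

    xyxy = F.convert_bounding_box_format(cxcywh, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    clamped = F.clamp_bounding_boxes(xyxy)  # no-op here: the box already fits the canvas
    back = F.convert_bounding_box_format(clamped, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)
    assert torch.equal(back, cxcywh)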
...@@ -819,7 +819,7 @@ def _affine_bounding_boxes_with_expand( ...@@ -819,7 +819,7 @@ def _affine_bounding_boxes_with_expand(
def affine_bounding_boxes( def affine_bounding_boxes(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int], canvas_size: Tuple[int, int],
angle: Union[int, float], angle: Union[int, float],
translate: List[float], translate: List[float],
...@@ -841,16 +841,16 @@ def affine_bounding_boxes( ...@@ -841,16 +841,16 @@ def affine_bounding_boxes(
return out_box return out_box
@_register_kernel_internal(affine, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _affine_bounding_boxes_dispatch( def _affine_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, inpt: tv_tensors.BoundingBoxes,
angle: Union[int, float], angle: Union[int, float],
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
**kwargs, **kwargs,
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output = affine_bounding_boxes( output = affine_bounding_boxes(
inpt.as_subclass(torch.Tensor), inpt.as_subclass(torch.Tensor),
format=inpt.format, format=inpt.format,
...@@ -861,7 +861,7 @@ def _affine_bounding_boxes_dispatch( ...@@ -861,7 +861,7 @@ def _affine_bounding_boxes_dispatch(
shear=shear, shear=shear,
center=center, center=center,
) )
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
def affine_mask( def affine_mask(
...@@ -896,9 +896,9 @@ def affine_mask( ...@@ -896,9 +896,9 @@ def affine_mask(
return output return output
@_register_kernel_internal(affine, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False)
def _affine_mask_dispatch( def _affine_mask_dispatch(
inpt: datapoints.Mask, inpt: tv_tensors.Mask,
angle: Union[int, float], angle: Union[int, float],
translate: List[float], translate: List[float],
scale: float, scale: float,
...@@ -906,7 +906,7 @@ def _affine_mask_dispatch( ...@@ -906,7 +906,7 @@ def _affine_mask_dispatch(
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
**kwargs, **kwargs,
) -> datapoints.Mask: ) -> tv_tensors.Mask:
output = affine_mask( output = affine_mask(
inpt.as_subclass(torch.Tensor), inpt.as_subclass(torch.Tensor),
angle=angle, angle=angle,
...@@ -916,10 +916,10 @@ def _affine_mask_dispatch( ...@@ -916,10 +916,10 @@ def _affine_mask_dispatch(
fill=fill, fill=fill,
center=center, center=center,
) )
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(affine, datapoints.Video) @_register_kernel_internal(affine, tv_tensors.Video)
def affine_video( def affine_video(
video: torch.Tensor, video: torch.Tensor,
angle: Union[int, float], angle: Union[int, float],
...@@ -961,7 +961,7 @@ def rotate( ...@@ -961,7 +961,7 @@ def rotate(
@_register_kernel_internal(rotate, torch.Tensor) @_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, datapoints.Image) @_register_kernel_internal(rotate, tv_tensors.Image)
def rotate_image( def rotate_image(
image: torch.Tensor, image: torch.Tensor,
angle: float, angle: float,
...@@ -1027,7 +1027,7 @@ def _rotate_image_pil( ...@@ -1027,7 +1027,7 @@ def _rotate_image_pil(
def rotate_bounding_boxes( def rotate_bounding_boxes(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int], canvas_size: Tuple[int, int],
angle: float, angle: float,
expand: bool = False, expand: bool = False,
...@@ -1049,10 +1049,10 @@ def rotate_bounding_boxes( ...@@ -1049,10 +1049,10 @@ def rotate_bounding_boxes(
) )
@_register_kernel_internal(rotate, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _rotate_bounding_boxes_dispatch( def _rotate_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output, canvas_size = rotate_bounding_boxes( output, canvas_size = rotate_bounding_boxes(
inpt.as_subclass(torch.Tensor), inpt.as_subclass(torch.Tensor),
format=inpt.format, format=inpt.format,
...@@ -1061,7 +1061,7 @@ def _rotate_bounding_boxes_dispatch( ...@@ -1061,7 +1061,7 @@ def _rotate_bounding_boxes_dispatch(
expand=expand, expand=expand,
center=center, center=center,
) )
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
def rotate_mask( def rotate_mask(
...@@ -1092,20 +1092,20 @@ def rotate_mask( ...@@ -1092,20 +1092,20 @@ def rotate_mask(
return output return output
@_register_kernel_internal(rotate, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False)
def _rotate_mask_dispatch( def _rotate_mask_dispatch(
inpt: datapoints.Mask, inpt: tv_tensors.Mask,
angle: float, angle: float,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
**kwargs, **kwargs,
) -> datapoints.Mask: ) -> tv_tensors.Mask:
output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center) output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(rotate, datapoints.Video) @_register_kernel_internal(rotate, tv_tensors.Video)
def rotate_video( def rotate_video(
video: torch.Tensor, video: torch.Tensor,
angle: float, angle: float,
...@@ -1158,7 +1158,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: ...@@ -1158,7 +1158,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
@_register_kernel_internal(pad, torch.Tensor) @_register_kernel_internal(pad, torch.Tensor)
@_register_kernel_internal(pad, datapoints.Image) @_register_kernel_internal(pad, tv_tensors.Image)
def pad_image( def pad_image(
image: torch.Tensor, image: torch.Tensor,
padding: List[int], padding: List[int],
...@@ -1260,7 +1260,7 @@ def _pad_with_vector_fill( ...@@ -1260,7 +1260,7 @@ def _pad_with_vector_fill(
_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad) _pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
@_register_kernel_internal(pad, datapoints.Mask) @_register_kernel_internal(pad, tv_tensors.Mask)
def pad_mask( def pad_mask(
mask: torch.Tensor, mask: torch.Tensor,
padding: List[int], padding: List[int],
...@@ -1289,7 +1289,7 @@ def pad_mask( ...@@ -1289,7 +1289,7 @@ def pad_mask(
def pad_bounding_boxes( def pad_bounding_boxes(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int], canvas_size: Tuple[int, int],
padding: List[int], padding: List[int],
padding_mode: str = "constant", padding_mode: str = "constant",
...@@ -1300,7 +1300,7 @@ def pad_bounding_boxes( ...@@ -1300,7 +1300,7 @@ def pad_bounding_boxes(
left, right, top, bottom = _parse_pad_padding(padding) left, right, top, bottom = _parse_pad_padding(padding)
if format == datapoints.BoundingBoxFormat.XYXY: if format == tv_tensors.BoundingBoxFormat.XYXY:
pad = [left, top, left, top] pad = [left, top, left, top]
else: else:
pad = [left, top, 0, 0] pad = [left, top, 0, 0]
...@@ -1314,10 +1314,10 @@ def pad_bounding_boxes( ...@@ -1314,10 +1314,10 @@ def pad_bounding_boxes(
return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
@_register_kernel_internal(pad, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _pad_bounding_boxes_dispatch( def _pad_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output, canvas_size = pad_bounding_boxes( output, canvas_size = pad_bounding_boxes(
inpt.as_subclass(torch.Tensor), inpt.as_subclass(torch.Tensor),
format=inpt.format, format=inpt.format,
...@@ -1325,10 +1325,10 @@ def _pad_bounding_boxes_dispatch( ...@@ -1325,10 +1325,10 @@ def _pad_bounding_boxes_dispatch(
padding=padding, padding=padding,
padding_mode=padding_mode, padding_mode=padding_mode,
) )
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(pad, datapoints.Video) @_register_kernel_internal(pad, tv_tensors.Video)
def pad_video( def pad_video(
video: torch.Tensor, video: torch.Tensor,
padding: List[int], padding: List[int],
...@@ -1350,7 +1350,7 @@ def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> to ...@@ -1350,7 +1350,7 @@ def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> to
@_register_kernel_internal(crop, torch.Tensor) @_register_kernel_internal(crop, torch.Tensor)
@_register_kernel_internal(crop, datapoints.Image) @_register_kernel_internal(crop, tv_tensors.Image)
def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
h, w = image.shape[-2:] h, w = image.shape[-2:]
...@@ -1375,7 +1375,7 @@ _register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil) ...@@ -1375,7 +1375,7 @@ _register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)
def crop_bounding_boxes( def crop_bounding_boxes(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
top: int, top: int,
left: int, left: int,
height: int, height: int,
...@@ -1383,7 +1383,7 @@ def crop_bounding_boxes( ...@@ -1383,7 +1383,7 @@ def crop_bounding_boxes(
) -> Tuple[torch.Tensor, Tuple[int, int]]: ) -> Tuple[torch.Tensor, Tuple[int, int]]:
# Crop or implicit pad if left and/or top have negative values: # Crop or implicit pad if left and/or top have negative values:
if format == datapoints.BoundingBoxFormat.XYXY: if format == tv_tensors.BoundingBoxFormat.XYXY:
sub = [left, top, left, top] sub = [left, top, left, top]
else: else:
sub = [left, top, 0, 0] sub = [left, top, 0, 0]
...@@ -1394,17 +1394,17 @@ def crop_bounding_boxes( ...@@ -1394,17 +1394,17 @@ def crop_bounding_boxes(
return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
@_register_kernel_internal(crop, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _crop_bounding_boxes_dispatch( def _crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output, canvas_size = crop_bounding_boxes( output, canvas_size = crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
) )
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(crop, datapoints.Mask) @_register_kernel_internal(crop, tv_tensors.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
if mask.ndim < 3: if mask.ndim < 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
...@@ -1420,7 +1420,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) ...@@ -1420,7 +1420,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
return output return output
@_register_kernel_internal(crop, datapoints.Video) @_register_kernel_internal(crop, tv_tensors.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image(video, top, left, height, width) return crop_image(video, top, left, height, width)
...@@ -1505,7 +1505,7 @@ def _perspective_coefficients( ...@@ -1505,7 +1505,7 @@ def _perspective_coefficients(
@_register_kernel_internal(perspective, torch.Tensor) @_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, datapoints.Image) @_register_kernel_internal(perspective, tv_tensors.Image)
def perspective_image( def perspective_image(
image: torch.Tensor, image: torch.Tensor,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
...@@ -1568,7 +1568,7 @@ def _perspective_image_pil( ...@@ -1568,7 +1568,7 @@ def _perspective_image_pil(
def perspective_bounding_boxes( def perspective_bounding_boxes(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int], canvas_size: Tuple[int, int],
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
...@@ -1582,7 +1582,7 @@ def perspective_bounding_boxes( ...@@ -1582,7 +1582,7 @@ def perspective_bounding_boxes(
original_shape = bounding_boxes.shape original_shape = bounding_boxes.shape
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
bounding_boxes = ( bounding_boxes = (
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
).reshape(-1, 4) ).reshape(-1, 4)
dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
...@@ -1649,25 +1649,25 @@ def perspective_bounding_boxes( ...@@ -1649,25 +1649,25 @@ def perspective_bounding_boxes(
out_bboxes = clamp_bounding_boxes( out_bboxes = clamp_bounding_boxes(
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype), torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
format=datapoints.BoundingBoxFormat.XYXY, format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=canvas_size, canvas_size=canvas_size,
) )
# out_bboxes should be of shape [N boxes, 4] # out_bboxes should be of shape [N boxes, 4]
return convert_bounding_box_format( return convert_bounding_box_format(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape) ).reshape(original_shape)
@_register_kernel_internal(perspective, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _perspective_bounding_boxes_dispatch( def _perspective_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, inpt: tv_tensors.BoundingBoxes,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
**kwargs, **kwargs,
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output = perspective_bounding_boxes( output = perspective_bounding_boxes(
inpt.as_subclass(torch.Tensor), inpt.as_subclass(torch.Tensor),
format=inpt.format, format=inpt.format,
...@@ -1676,7 +1676,7 @@ def _perspective_bounding_boxes_dispatch( ...@@ -1676,7 +1676,7 @@ def _perspective_bounding_boxes_dispatch(
endpoints=endpoints, endpoints=endpoints,
coefficients=coefficients, coefficients=coefficients,
) )
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
def perspective_mask( def perspective_mask(
...@@ -1702,15 +1702,15 @@ def perspective_mask( ...@@ -1702,15 +1702,15 @@ def perspective_mask(
return output return output
@_register_kernel_internal(perspective, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False)
def _perspective_mask_dispatch( def _perspective_mask_dispatch(
inpt: datapoints.Mask, inpt: tv_tensors.Mask,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
fill: _FillTypeJIT = None, fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
**kwargs, **kwargs,
) -> datapoints.Mask: ) -> tv_tensors.Mask:
output = perspective_mask( output = perspective_mask(
inpt.as_subclass(torch.Tensor), inpt.as_subclass(torch.Tensor),
startpoints=startpoints, startpoints=startpoints,
...@@ -1718,10 +1718,10 @@ def _perspective_mask_dispatch( ...@@ -1718,10 +1718,10 @@ def _perspective_mask_dispatch(
fill=fill, fill=fill,
coefficients=coefficients, coefficients=coefficients,
) )
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(perspective, datapoints.Video) @_register_kernel_internal(perspective, tv_tensors.Video)
def perspective_video( def perspective_video(
video: torch.Tensor, video: torch.Tensor,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
...@@ -1755,7 +1755,7 @@ elastic_transform = elastic ...@@ -1755,7 +1755,7 @@ elastic_transform = elastic
@_register_kernel_internal(elastic, torch.Tensor) @_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, datapoints.Image) @_register_kernel_internal(elastic, tv_tensors.Image)
def elastic_image( def elastic_image(
image: torch.Tensor, image: torch.Tensor,
displacement: torch.Tensor, displacement: torch.Tensor,
...@@ -1841,7 +1841,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to ...@@ -1841,7 +1841,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to
def elastic_bounding_boxes( def elastic_bounding_boxes(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int], canvas_size: Tuple[int, int],
displacement: torch.Tensor, displacement: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1864,7 +1864,7 @@ def elastic_bounding_boxes( ...@@ -1864,7 +1864,7 @@ def elastic_bounding_boxes(
original_shape = bounding_boxes.shape original_shape = bounding_boxes.shape
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
bounding_boxes = ( bounding_boxes = (
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
).reshape(-1, 4) ).reshape(-1, 4)
id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
...@@ -1887,23 +1887,23 @@ def elastic_bounding_boxes( ...@@ -1887,23 +1887,23 @@ def elastic_bounding_boxes(
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
out_bboxes = clamp_bounding_boxes( out_bboxes = clamp_bounding_boxes(
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype), torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
format=datapoints.BoundingBoxFormat.XYXY, format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=canvas_size, canvas_size=canvas_size,
) )
return convert_bounding_box_format( return convert_bounding_box_format(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape) ).reshape(original_shape)
@_register_kernel_internal(elastic, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _elastic_bounding_boxes_dispatch( def _elastic_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, displacement: torch.Tensor, **kwargs inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output = elastic_bounding_boxes( output = elastic_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
) )
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
def elastic_mask( def elastic_mask(
...@@ -1925,15 +1925,15 @@ def elastic_mask( ...@@ -1925,15 +1925,15 @@ def elastic_mask(
return output return output
@_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
def _elastic_mask_dispatch( def _elastic_mask_dispatch(
inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> datapoints.Mask: ) -> tv_tensors.Mask:
output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill) output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(elastic, datapoints.Video) @_register_kernel_internal(elastic, tv_tensors.Video)
def elastic_video( def elastic_video(
video: torch.Tensor, video: torch.Tensor,
displacement: torch.Tensor, displacement: torch.Tensor,
...@@ -1982,7 +1982,7 @@ def _center_crop_compute_crop_anchor( ...@@ -1982,7 +1982,7 @@ def _center_crop_compute_crop_anchor(
@_register_kernel_internal(center_crop, torch.Tensor) @_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, datapoints.Image) @_register_kernel_internal(center_crop, tv_tensors.Image)
def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size) crop_height, crop_width = _center_crop_parse_output_size(output_size)
shape = image.shape shape = image.shape
...@@ -2021,7 +2021,7 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PI ...@@ -2021,7 +2021,7 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PI
def center_crop_bounding_boxes( def center_crop_bounding_boxes(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int], canvas_size: Tuple[int, int],
output_size: List[int], output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]: ) -> Tuple[torch.Tensor, Tuple[int, int]]:
...@@ -2032,17 +2032,17 @@ def center_crop_bounding_boxes( ...@@ -2032,17 +2032,17 @@ def center_crop_bounding_boxes(
) )
@_register_kernel_internal(center_crop, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _center_crop_bounding_boxes_dispatch( def _center_crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, output_size: List[int] inpt: tv_tensors.BoundingBoxes, output_size: List[int]
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output, canvas_size = center_crop_bounding_boxes( output, canvas_size = center_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
) )
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(center_crop, datapoints.Mask) @_register_kernel_internal(center_crop, tv_tensors.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor: def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
if mask.ndim < 3: if mask.ndim < 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
...@@ -2058,7 +2058,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor ...@@ -2058,7 +2058,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
return output return output
@_register_kernel_internal(center_crop, datapoints.Video) @_register_kernel_internal(center_crop, tv_tensors.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor: def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
return center_crop_image(video, output_size) return center_crop_image(video, output_size)
...@@ -2102,7 +2102,7 @@ def resized_crop( ...@@ -2102,7 +2102,7 @@ def resized_crop(
@_register_kernel_internal(resized_crop, torch.Tensor) @_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, datapoints.Image) @_register_kernel_internal(resized_crop, tv_tensors.Image)
def resized_crop_image( def resized_crop_image(
image: torch.Tensor, image: torch.Tensor,
top: int, top: int,
...@@ -2156,7 +2156,7 @@ def _resized_crop_image_pil_dispatch( ...@@ -2156,7 +2156,7 @@ def _resized_crop_image_pil_dispatch(
def resized_crop_bounding_boxes( def resized_crop_bounding_boxes(
bounding_boxes: torch.Tensor, bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat, format: tv_tensors.BoundingBoxFormat,
top: int, top: int,
left: int, left: int,
height: int, height: int,
...@@ -2167,14 +2167,14 @@ def resized_crop_bounding_boxes( ...@@ -2167,14 +2167,14 @@ def resized_crop_bounding_boxes(
return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size) return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)
@_register_kernel_internal(resized_crop, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch( def _resized_crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> datapoints.BoundingBoxes: ) -> tv_tensors.BoundingBoxes:
output, canvas_size = resized_crop_bounding_boxes( output, canvas_size = resized_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
) )
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
def resized_crop_mask( def resized_crop_mask(
...@@ -2189,17 +2189,17 @@ def resized_crop_mask( ...@@ -2189,17 +2189,17 @@ def resized_crop_mask(
return resize_mask(mask, size) return resize_mask(mask, size)
@_register_kernel_internal(resized_crop, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resized_crop_mask_dispatch( def _resized_crop_mask_dispatch(
inpt: datapoints.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> datapoints.Mask: ) -> tv_tensors.Mask:
output = resized_crop_mask( output = resized_crop_mask(
inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
) )
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(resized_crop, datapoints.Video) @_register_kernel_internal(resized_crop, tv_tensors.Video)
def resized_crop_video( def resized_crop_video(
video: torch.Tensor, video: torch.Tensor,
top: int, top: int,
...@@ -2243,7 +2243,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: ...@@ -2243,7 +2243,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor) @_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image) @_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
def five_crop_image( def five_crop_image(
image: torch.Tensor, size: List[int] image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...@@ -2281,7 +2281,7 @@ def _five_crop_image_pil( ...@@ -2281,7 +2281,7 @@ def _five_crop_image_pil(
return tl, tr, bl, br, center return tl, tr, bl, br, center
@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video) @_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
def five_crop_video( def five_crop_video(
video: torch.Tensor, size: List[int] video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...@@ -2313,7 +2313,7 @@ def ten_crop( ...@@ -2313,7 +2313,7 @@ def ten_crop(
@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor) @_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image) @_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image)
def ten_crop_image( def ten_crop_image(
image: torch.Tensor, size: List[int], vertical_flip: bool = False image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[ ) -> Tuple[
...@@ -2367,7 +2367,7 @@ def _ten_crop_image_pil( ...@@ -2367,7 +2367,7 @@ def _ten_crop_image_pil(
return non_flipped + flipped return non_flipped + flipped
@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video) @_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video)
def ten_crop_video( def ten_crop_video(
video: torch.Tensor, size: List[int], vertical_flip: bool = False video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[ ) -> Tuple[
......
...@@ -2,9 +2,9 @@ from typing import List, Optional, Tuple ...@@ -2,9 +2,9 @@ from typing import List, Optional, Tuple
import PIL.Image import PIL.Image
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.datapoints import BoundingBoxFormat
from torchvision.transforms import _functional_pil as _FP from torchvision.transforms import _functional_pil as _FP
from torchvision.tv_tensors import BoundingBoxFormat
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
...@@ -22,7 +22,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]: ...@@ -22,7 +22,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_dimensions, torch.Tensor) @_register_kernel_internal(get_dimensions, torch.Tensor)
@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False) @_register_kernel_internal(get_dimensions, tv_tensors.Image, tv_tensor_wrapper=False)
def get_dimensions_image(image: torch.Tensor) -> List[int]: def get_dimensions_image(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:]) chw = list(image.shape[-3:])
ndims = len(chw) ndims = len(chw)
...@@ -38,7 +38,7 @@ def get_dimensions_image(image: torch.Tensor) -> List[int]: ...@@ -38,7 +38,7 @@ def get_dimensions_image(image: torch.Tensor) -> List[int]:
_get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions) _get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False) @_register_kernel_internal(get_dimensions, tv_tensors.Video, tv_tensor_wrapper=False)
def get_dimensions_video(video: torch.Tensor) -> List[int]: def get_dimensions_video(video: torch.Tensor) -> List[int]:
return get_dimensions_image(video) return get_dimensions_image(video)
...@@ -54,7 +54,7 @@ def get_num_channels(inpt: torch.Tensor) -> int: ...@@ -54,7 +54,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
@_register_kernel_internal(get_num_channels, torch.Tensor) @_register_kernel_internal(get_num_channels, torch.Tensor)
@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False) @_register_kernel_internal(get_num_channels, tv_tensors.Image, tv_tensor_wrapper=False)
def get_num_channels_image(image: torch.Tensor) -> int: def get_num_channels_image(image: torch.Tensor) -> int:
chw = image.shape[-3:] chw = image.shape[-3:]
ndims = len(chw) ndims = len(chw)
...@@ -69,7 +69,7 @@ def get_num_channels_image(image: torch.Tensor) -> int: ...@@ -69,7 +69,7 @@ def get_num_channels_image(image: torch.Tensor) -> int:
_get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels) _get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False) @_register_kernel_internal(get_num_channels, tv_tensors.Video, tv_tensor_wrapper=False)
def get_num_channels_video(video: torch.Tensor) -> int: def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image(video) return get_num_channels_image(video)
...@@ -90,7 +90,7 @@ def get_size(inpt: torch.Tensor) -> List[int]: ...@@ -90,7 +90,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_size, torch.Tensor) @_register_kernel_internal(get_size, torch.Tensor)
@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False) @_register_kernel_internal(get_size, tv_tensors.Image, tv_tensor_wrapper=False)
def get_size_image(image: torch.Tensor) -> List[int]: def get_size_image(image: torch.Tensor) -> List[int]:
hw = list(image.shape[-2:]) hw = list(image.shape[-2:])
ndims = len(hw) ndims = len(hw)
...@@ -106,18 +106,18 @@ def _get_size_image_pil(image: PIL.Image.Image) -> List[int]: ...@@ -106,18 +106,18 @@ def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
return [height, width] return [height, width]
@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
def get_size_video(video: torch.Tensor) -> List[int]: def get_size_video(video: torch.Tensor) -> List[int]:
return get_size_image(video) return get_size_image(video)
@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(get_size, tv_tensors.Mask, tv_tensor_wrapper=False)
def get_size_mask(mask: torch.Tensor) -> List[int]: def get_size_mask(mask: torch.Tensor) -> List[int]:
return get_size_image(mask) return get_size_image(mask)
@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(get_size, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]: def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> List[int]:
return list(bounding_box.canvas_size) return list(bounding_box.canvas_size)
...@@ -132,7 +132,7 @@ def get_num_frames(inpt: torch.Tensor) -> int: ...@@ -132,7 +132,7 @@ def get_num_frames(inpt: torch.Tensor) -> int:
@_register_kernel_internal(get_num_frames, torch.Tensor) @_register_kernel_internal(get_num_frames, torch.Tensor)
@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False) @_register_kernel_internal(get_num_frames, tv_tensors.Video, tv_tensor_wrapper=False)
def get_num_frames_video(video: torch.Tensor) -> int: def get_num_frames_video(video: torch.Tensor) -> int:
return video.shape[-4] return video.shape[-4]
...@@ -205,7 +205,7 @@ def convert_bounding_box_format( ...@@ -205,7 +205,7 @@ def convert_bounding_box_format(
) -> torch.Tensor: ) -> torch.Tensor:
"""[BETA] See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details.""" """[BETA] See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details."""
# This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on # inputs as well as extract it from `tv_tensors.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value. # default error that would be thrown if `new_format` had no default value.
if new_format is None: if new_format is None:
...@@ -218,16 +218,16 @@ def convert_bounding_box_format( ...@@ -218,16 +218,16 @@ def convert_bounding_box_format(
if old_format is None: if old_format is None:
raise ValueError("For pure tensor inputs, `old_format` has to be passed.") raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace) return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
elif isinstance(inpt, datapoints.BoundingBoxes): elif isinstance(inpt, tv_tensors.BoundingBoxes):
if old_format is not None: if old_format is not None:
raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.") raise ValueError("For bounding box tv_tensor inputs, `old_format` must not be passed.")
output = _convert_bounding_box_format( output = _convert_bounding_box_format(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
) )
return datapoints.wrap(output, like=inpt, format=new_format) return tv_tensors.wrap(output, like=inpt, format=new_format)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
) )
...@@ -239,7 +239,7 @@ def _clamp_bounding_boxes( ...@@ -239,7 +239,7 @@ def _clamp_bounding_boxes(
in_dtype = bounding_boxes.dtype in_dtype = bounding_boxes.dtype
bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float() bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
xyxy_boxes = convert_bounding_box_format( xyxy_boxes = convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
) )
xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1]) xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0]) xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
...@@ -263,12 +263,12 @@ def clamp_bounding_boxes( ...@@ -263,12 +263,12 @@ def clamp_bounding_boxes(
if format is None or canvas_size is None: if format is None or canvas_size is None:
raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.") raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size) return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
elif isinstance(inpt, datapoints.BoundingBoxes): elif isinstance(inpt, tv_tensors.BoundingBoxes):
if format is not None or canvas_size is not None: if format is not None or canvas_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.") raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size) output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
) )
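The two kernel/functional hybrids above (convert_bounding_box_format and clamp_bounding_boxes) accept either a plain tensor plus explicit metadata, or a BoundingBoxes tv_tensor whose metadata is read from the input itself. A minimal sketch of the two call paths, assuming only the torchvision.transforms.v2 functional API shown in this diff:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # XYWH coordinates as a plain tensor

# Pure tensor input: the metadata must be passed explicitly.
out = F.convert_bounding_box_format(
    boxes,
    old_format=tv_tensors.BoundingBoxFormat.XYWH,
    new_format=tv_tensors.BoundingBoxFormat.XYXY,
)
out = F.clamp_bounding_boxes(out, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(50, 50))

# BoundingBoxes input: the metadata comes from the tv_tensor and must *not* be passed.
boxes_tv = tv_tensors.BoundingBoxes(boxes, format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=(50, 50))
out = F.convert_bounding_box_format(boxes_tv, new_format=tv_tensors.BoundingBoxFormat.XYXY)
out = F.clamp_bounding_boxes(out)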
...@@ -5,7 +5,7 @@ import PIL.Image ...@@ -5,7 +5,7 @@ import PIL.Image
import torch import torch
from torch.nn.functional import conv2d, pad as torch_pad from torch.nn.functional import conv2d, pad as torch_pad
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value from torchvision.transforms._functional_tensor import _max_value
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
...@@ -31,7 +31,7 @@ def normalize( ...@@ -31,7 +31,7 @@ def normalize(
@_register_kernel_internal(normalize, torch.Tensor) @_register_kernel_internal(normalize, torch.Tensor)
@_register_kernel_internal(normalize, datapoints.Image) @_register_kernel_internal(normalize, tv_tensors.Image)
def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
if not image.is_floating_point(): if not image.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.") raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
...@@ -65,7 +65,7 @@ def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], in ...@@ -65,7 +65,7 @@ def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], in
return image.div_(std) return image.div_(std)
@_register_kernel_internal(normalize, datapoints.Video) @_register_kernel_internal(normalize, tv_tensors.Video)
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
return normalize_image(video, mean, std, inplace=inplace) return normalize_image(video, mean, std, inplace=inplace)
...@@ -98,7 +98,7 @@ def _get_gaussian_kernel2d( ...@@ -98,7 +98,7 @@ def _get_gaussian_kernel2d(
@_register_kernel_internal(gaussian_blur, torch.Tensor) @_register_kernel_internal(gaussian_blur, torch.Tensor)
@_register_kernel_internal(gaussian_blur, datapoints.Image) @_register_kernel_internal(gaussian_blur, tv_tensors.Image)
def gaussian_blur_image( def gaussian_blur_image(
image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -172,7 +172,7 @@ def _gaussian_blur_image_pil( ...@@ -172,7 +172,7 @@ def _gaussian_blur_image_pil(
return to_pil_image(output, mode=image.mode) return to_pil_image(output, mode=image.mode)
@_register_kernel_internal(gaussian_blur, datapoints.Video) @_register_kernel_internal(gaussian_blur, tv_tensors.Video)
def gaussian_blur_video( def gaussian_blur_video(
video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -206,7 +206,7 @@ def _num_value_bits(dtype: torch.dtype) -> int: ...@@ -206,7 +206,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
@_register_kernel_internal(to_dtype, torch.Tensor) @_register_kernel_internal(to_dtype, torch.Tensor)
@_register_kernel_internal(to_dtype, datapoints.Image) @_register_kernel_internal(to_dtype, tv_tensors.Image)
def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if image.dtype == dtype: if image.dtype == dtype:
...@@ -265,13 +265,13 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) ...@@ -265,13 +265,13 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32)
return to_dtype_image(image, dtype=dtype, scale=True) return to_dtype_image(image, dtype=dtype, scale=True)
@_register_kernel_internal(to_dtype, datapoints.Video) @_register_kernel_internal(to_dtype, tv_tensors.Video)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
return to_dtype_image(video, dtype, scale=scale) return to_dtype_image(video, dtype, scale=scale)
@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(to_dtype, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(to_dtype, tv_tensors.Mask, tv_tensor_wrapper=False)
def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor: def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
# We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type # We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type
return inpt.to(dtype) return inpt.to(dtype)
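As the comment notes, calling .to() on a TVTensor keeps the subclass, so the dtype dispatch for bounding boxes and masks can simply return inpt.to(dtype). A small sketch of that behaviour:

import torch
from torchvision import tv_tensors

mask = tv_tensors.Mask(torch.zeros(1, 4, 4, dtype=torch.uint8))
converted = mask.to(torch.float32)
# The subclass survives the dtype change, so no unwrap/rewrap is needed here.
assert isinstance(converted, tv_tensors.Mask) and converted.dtype == torch.float32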
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
...@@ -19,7 +19,7 @@ def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Te ...@@ -19,7 +19,7 @@ def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Te
@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor) @_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video) @_register_kernel_internal(uniform_temporal_subsample, tv_tensors.Video)
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
t_max = video.shape[-4] - 1 t_max = video.shape[-4] - 1
......
...@@ -3,12 +3,12 @@ from typing import Union ...@@ -3,12 +3,12 @@ from typing import Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms import functional as _F from torchvision.transforms import functional as _F
@torch.jit.unused @torch.jit.unused
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image: def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image:
"""[BETA] See :class:`~torchvision.transforms.v2.ToImage` for details.""" """[BETA] See :class:`~torchvision.transforms.v2.ToImage` for details."""
if isinstance(inpt, np.ndarray): if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous() output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
...@@ -18,7 +18,7 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoin ...@@ -18,7 +18,7 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoin
output = inpt output = inpt
else: else:
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.") raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
return datapoints.Image(output) return tv_tensors.Image(output)
to_pil_image = _F.to_pil_image to_pil_image = _F.to_pil_image
......
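For reference, a minimal sketch of to_image on an HWC numpy array, which is permuted to CHW and wrapped into a tv_tensors.Image as shown above:

import numpy as np
from torchvision.transforms.v2 import functional as F

arr = np.zeros((32, 32, 3), dtype=np.uint8)  # HWC numpy array
img = F.to_image(arr)                        # permuted to CHW and wrapped
assert img.shape == (3, 32, 32)              # img is a tv_tensors.Image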
...@@ -2,21 +2,21 @@ import functools ...@@ -2,21 +2,21 @@ import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
_FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]] _FillTypeJIT = Optional[List[float]]
def is_pure_tensor(inpt: Any) -> bool: def is_pure_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint) return isinstance(inpt, torch.Tensor) and not isinstance(inpt, tv_tensors.TVTensor)
# {functional: {input_type: type_specific_kernel}} # {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {} _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
def _kernel_datapoint_wrapper(kernel): def _kernel_tv_tensor_wrapper(kernel):
@functools.wraps(kernel) @functools.wraps(kernel)
def wrapper(inpt, *args, **kwargs): def wrapper(inpt, *args, **kwargs):
# If you're wondering whether we could / should get rid of this wrapper, # If you're wondering whether we could / should get rid of this wrapper,
...@@ -25,24 +25,24 @@ def _kernel_datapoint_wrapper(kernel): ...@@ -25,24 +25,24 @@ def _kernel_datapoint_wrapper(kernel):
# regardless of whether we override __torch_function__ in our base class # regardless of whether we override __torch_function__ in our base class
# or not. # or not.
# Also, even if we didn't call `as_subclass` here, we would still need # Also, even if we didn't call `as_subclass` here, we would still need
# this wrapper to call wrap(), because the Datapoint type would be # this wrapper to call wrap(), because the TVTensor type would be
# lost after the first operation due to our own __torch_function__ # lost after the first operation due to our own __torch_function__
# logic. # logic.
output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
return datapoints.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
return wrapper return wrapper
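The practical effect of this wrapper is that functionals called on built-in tv_tensors hand back the same tv_tensor type rather than a plain Tensor. A minimal illustration, assuming the kernel registrations shown elsewhere in this diff:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = tv_tensors.Image(torch.rand(3, 8, 8))
out = F.gaussian_blur(img, kernel_size=[3, 3])
# The registered Image kernel is wrapped by _kernel_tv_tensor_wrapper, so the
# result is re-wrapped into a tv_tensors.Image instead of decaying to torch.Tensor.
assert isinstance(out, tv_tensors.Image)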
def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True): def _register_kernel_internal(functional, input_type, *, tv_tensor_wrapper=True):
registry = _KERNEL_REGISTRY.setdefault(functional, {}) registry = _KERNEL_REGISTRY.setdefault(functional, {})
if input_type in registry: if input_type in registry:
raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.") raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")
def decorator(kernel): def decorator(kernel):
registry[input_type] = ( registry[input_type] = (
_kernel_datapoint_wrapper(kernel) _kernel_tv_tensor_wrapper(kernel)
if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper if issubclass(input_type, tv_tensors.TVTensor) and tv_tensor_wrapper
else kernel else kernel
) )
return kernel return kernel
...@@ -62,14 +62,14 @@ def _name_to_functional(name): ...@@ -62,14 +62,14 @@ def _name_to_functional(name):
_BUILTIN_DATAPOINT_TYPES = { _BUILTIN_DATAPOINT_TYPES = {
obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint) obj for obj in tv_tensors.__dict__.values() if isinstance(obj, type) and issubclass(obj, tv_tensors.TVTensor)
} }
def register_kernel(functional, datapoint_cls): def register_kernel(functional, tv_tensor_cls):
"""[BETA] Decorate a kernel to register it for a functional and a (custom) datapoint type. """[BETA] Decorate a kernel to register it for a functional and a (custom) tv_tensor type.
See :ref:`sphx_glr_auto_examples_transforms_plot_custom_datapoints.py` for usage See :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for usage
details. details.
""" """
if isinstance(functional, str): if isinstance(functional, str):
...@@ -83,16 +83,16 @@ def register_kernel(functional, datapoint_cls): ...@@ -83,16 +83,16 @@ def register_kernel(functional, datapoint_cls):
f"but got {functional}." f"but got {functional}."
) )
if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)): if not (isinstance(tv_tensor_cls, type) and issubclass(tv_tensor_cls, tv_tensors.TVTensor)):
raise ValueError( raise ValueError(
f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, " f"Kernels can only be registered for subclasses of torchvision.tv_tensors.TVTensor, "
f"but got {datapoint_cls}." f"but got {tv_tensor_cls}."
) )
if datapoint_cls in _BUILTIN_DATAPOINT_TYPES: if tv_tensor_cls in _BUILTIN_DATAPOINT_TYPES:
raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}") raise ValueError(f"Kernels cannot be registered for the builtin tv_tensor classes, but got {tv_tensor_cls}")
return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False) return _register_kernel_internal(functional, tv_tensor_cls, tv_tensor_wrapper=False)
def _get_kernel(functional, input_type, *, allow_passthrough=False): def _get_kernel(functional, input_type, *, allow_passthrough=False):
...@@ -103,10 +103,10 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False): ...@@ -103,10 +103,10 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False):
for cls in input_type.__mro__: for cls in input_type.__mro__:
if cls in registry: if cls in registry:
return registry[cls] return registry[cls]
elif cls is datapoints.Datapoint: elif cls is tv_tensors.TVTensor:
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicitly stop the # We don't want user-defined tv_tensors to dispatch to the pure Tensor kernels, so we explicitly stop the

# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't # MRO traversal before hitting torch.Tensor. We can even stop at tv_tensors.TVTensor, since we don't
# allow kernels to be registered for datapoints.Datapoint anyway. # allow kernels to be registered for tv_tensors.TVTensor anyway.
break break
if allow_passthrough: if allow_passthrough:
...@@ -130,12 +130,12 @@ def _register_five_ten_crop_kernel_internal(functional, input_type): ...@@ -130,12 +130,12 @@ def _register_five_ten_crop_kernel_internal(functional, input_type):
def wrapper(inpt, *args, **kwargs): def wrapper(inpt, *args, **kwargs):
output = kernel(inpt, *args, **kwargs) output = kernel(inpt, *args, **kwargs)
container_type = type(output) container_type = type(output)
return container_type(datapoints.wrap(o, like=inpt) for o in output) return container_type(tv_tensors.wrap(o, like=inpt) for o in output)
return wrapper return wrapper
def decorator(kernel): def decorator(kernel):
registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel registry[input_type] = wrap(kernel) if issubclass(input_type, tv_tensors.TVTensor) else kernel
return kernel return kernel
return decorator return decorator
import torch import torch
from ._bounding_box import BoundingBoxes, BoundingBoxFormat from ._bounding_box import BoundingBoxes, BoundingBoxFormat
from ._datapoint import Datapoint
from ._image import Image from ._image import Image
from ._mask import Mask from ._mask import Mask
from ._torch_function_helpers import set_return_type from ._torch_function_helpers import set_return_type
from ._tv_tensor import TVTensor
from ._video import Video from ._video import Video
def wrap(wrappee, *, like, **kwargs): def wrap(wrappee, *, like, **kwargs):
"""[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.datapoints.Datapoint` subclass as ``like``. """[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the ``format`` and ``canvas_size`` of If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``. ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.
Args: Args:
wrappee (Tensor): The tensor to convert. wrappee (Tensor): The tensor to convert.
like (:class:`~torchvision.datapoints.Datapoint`): The reference. like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
``wrappee`` will be converted into the same subclass as ``like``. ``wrappee`` will be converted into the same subclass as ``like``.
kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.datapoint.BoundingBoxes`. kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.tv_tensor.BoundingBoxes`.
Ignored otherwise. Ignored otherwise.
""" """
if isinstance(like, BoundingBoxes): if isinstance(like, BoundingBoxes):
......
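A minimal sketch of wrap in action: an operation on the underlying tensor produces a plain Tensor, and wrap restores the BoundingBoxes subclass with the metadata taken from ``like``:

import torch
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),
)
shifted = boxes.as_subclass(torch.Tensor) + 2   # plain torch.Tensor
shifted = tv_tensors.wrap(shifted, like=boxes)  # BoundingBoxes again, same format/canvas_size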
...@@ -6,7 +6,7 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union ...@@ -6,7 +6,7 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union
import torch import torch
from torch.utils._pytree import tree_flatten from torch.utils._pytree import tree_flatten
from ._datapoint import Datapoint from ._tv_tensor import TVTensor
class BoundingBoxFormat(Enum): class BoundingBoxFormat(Enum):
...@@ -24,13 +24,13 @@ class BoundingBoxFormat(Enum): ...@@ -24,13 +24,13 @@ class BoundingBoxFormat(Enum):
CXCYWH = "CXCYWH" CXCYWH = "CXCYWH"
class BoundingBoxes(Datapoint): class BoundingBoxes(TVTensor):
"""[BETA] :class:`torch.Tensor` subclass for bounding boxes. """[BETA] :class:`torch.Tensor` subclass for bounding boxes.
.. note:: .. note::
There should be only one :class:`~torchvision.datapoints.BoundingBoxes` There should be only one :class:`~torchvision.tv_tensors.BoundingBoxes`
instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``, instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
although one :class:`~torchvision.datapoints.BoundingBoxes` object can although one :class:`~torchvision.tv_tensors.BoundingBoxes` object can
contain multiple bounding boxes. contain multiple bounding boxes.
Args: Args:
......
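As the note above says, a sample should carry a single BoundingBoxes instance holding all of its boxes, for example:

import torch
from torchvision import tv_tensors

sample = {
    "img": tv_tensors.Image(torch.rand(3, 64, 64)),
    "bbox": tv_tensors.BoundingBoxes(           # one instance, possibly many boxes
        torch.tensor([[0, 0, 10, 10], [5, 5, 20, 20]]),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(64, 64),
    ),
    "label": torch.tensor([1, 2]),
}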
...@@ -9,7 +9,7 @@ from collections import defaultdict ...@@ -9,7 +9,7 @@ from collections import defaultdict
import torch import torch
from torchvision import datapoints, datasets from torchvision import datasets, tv_tensors
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
__all__ = ["wrap_dataset_for_transforms_v2"] __all__ = ["wrap_dataset_for_transforms_v2"]
...@@ -36,26 +36,26 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): ...@@ -36,26 +36,26 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
* :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format), returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``. ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.tv_tensors``.
The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the
``"image_id"``, ``"boxes"``, and ``"labels"``. ``"image_id"``, ``"boxes"``, and ``"labels"``.
* :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are the target and wrap the data in the corresponding ``torchvision.tv_tensors``. The original keys are
preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``. preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
* :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY`` * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint. coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
* :class:`~torchvision.datasets.Kitti`: Instead of returning the target as list of dicts, the wrapper returns a * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as list of dicts, the wrapper returns a
dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is in the corresponding ``torchvision.tv_tensors``. The original keys are preserved. If ``target_keys`` is
omitted, returns only the values for the ``"boxes"`` and ``"labels"``. omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
* :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
:class:`~torchvision.datapoints.Mask` datapoint. :class:`~torchvision.tv_tensors.Mask` tv_tensor.
* :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
:class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by :class:`~torchvision.tv_tensors.Mask` tv_tensor. The target for ``target_type="instance"`` is *replaced* by
a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.tv_tensors.Mask` tv_tensor) and
``"labels"``. ``"labels"``.
* :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY`` * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint. coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
Image classification datasets Image classification datasets
...@@ -66,13 +66,13 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): ...@@ -66,13 +66,13 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
:class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item). segmentation mask into a :class:`~torchvision.tv_tensors.Mask` (second item).
Video classification datasets Video classification datasets
Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
:class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
:class:`~torchvision.datapoints.Video` while leaving the other items as is. :class:`~torchvision.tv_tensors.Video` while leaving the other items as is.
.. note:: .. note::
...@@ -98,12 +98,12 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): ...@@ -98,12 +98,12 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
) )
# Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetTVTensorWrapper (see below) as well as the
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks, # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
# while we can still inject everything that we need. # while we can still inject everything that we need.
wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {}) wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetTVTensorWrapper, type(dataset)), {})
# Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits # Since VisionDatasetTVTensorWrapper comes before ImageNet in the MRO, calling the class hits
# VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of # VisionDatasetTVTensorWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
# ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather # ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
# have the existing instance as attribute on the new object. # have the existing instance as attribute on the new object.
return wrapped_dataset_cls(dataset, target_keys) return wrapped_dataset_cls(dataset, target_keys)
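A hedged usage sketch of the wrapper described above; the dataset paths are placeholders, and the returned target keys follow the CocoDetection bullet in the docstring:

from torchvision import datasets
from torchvision.datasets import wrap_dataset_for_transforms_v2

# Placeholder paths for illustration only.
dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks"))

img, target = dataset[0]
# target["boxes"] is a tv_tensors.BoundingBoxes (XYXY), target["masks"] a tv_tensors.Mask,
# and target["labels"] a plain tensor, as described above.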
...@@ -125,7 +125,7 @@ class WrapperFactories(dict): ...@@ -125,7 +125,7 @@ class WrapperFactories(dict):
WRAPPER_FACTORIES = WrapperFactories() WRAPPER_FACTORIES = WrapperFactories()
class VisionDatasetDatapointWrapper: class VisionDatasetTVTensorWrapper:
def __init__(self, dataset, target_keys): def __init__(self, dataset, target_keys):
dataset_cls = type(dataset) dataset_cls = type(dataset)
...@@ -134,7 +134,7 @@ class VisionDatasetDatapointWrapper: ...@@ -134,7 +134,7 @@ class VisionDatasetDatapointWrapper:
f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, " f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
f"but got a '{dataset_cls.__name__}' instead.\n" f"but got a '{dataset_cls.__name__}' instead.\n"
f"For an example of how to perform the wrapping for custom datasets, see\n\n" f"For an example of how to perform the wrapping for custom datasets, see\n\n"
"https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself" "https://pytorch.org/vision/main/auto_examples/plot_tv_tensors.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
) )
for cls in dataset_cls.mro(): for cls in dataset_cls.mro():
...@@ -221,7 +221,7 @@ def identity_wrapper_factory(dataset, target_keys): ...@@ -221,7 +221,7 @@ def identity_wrapper_factory(dataset, target_keys):
def pil_image_to_mask(pil_image): def pil_image_to_mask(pil_image):
return datapoints.Mask(pil_image) return tv_tensors.Mask(pil_image)
def parse_target_keys(target_keys, *, available, default): def parse_target_keys(target_keys, *, available, default):
...@@ -302,7 +302,7 @@ def video_classification_wrapper_factory(dataset, target_keys): ...@@ -302,7 +302,7 @@ def video_classification_wrapper_factory(dataset, target_keys):
def wrapper(idx, sample): def wrapper(idx, sample):
video, audio, label = sample video, audio, label = sample
video = datapoints.Video(video) video = tv_tensors.Video(video)
return video, audio, label return video, audio, label
...@@ -373,16 +373,16 @@ def coco_dectection_wrapper_factory(dataset, target_keys): ...@@ -373,16 +373,16 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
if "boxes" in target_keys: if "boxes" in target_keys:
target["boxes"] = F.convert_bounding_box_format( target["boxes"] = F.convert_bounding_box_format(
datapoints.BoundingBoxes( tv_tensors.BoundingBoxes(
batched_target["bbox"], batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYWH, format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=canvas_size, canvas_size=canvas_size,
), ),
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=tv_tensors.BoundingBoxFormat.XYXY,
) )
if "masks" in target_keys: if "masks" in target_keys:
target["masks"] = datapoints.Mask( target["masks"] = tv_tensors.Mask(
torch.stack( torch.stack(
[ [
segmentation_to_mask(segmentation, canvas_size=canvas_size) segmentation_to_mask(segmentation, canvas_size=canvas_size)
...@@ -454,12 +454,12 @@ def voc_detection_wrapper_factory(dataset, target_keys): ...@@ -454,12 +454,12 @@ def voc_detection_wrapper_factory(dataset, target_keys):
target = {} target = {}
if "boxes" in target_keys: if "boxes" in target_keys:
target["boxes"] = datapoints.BoundingBoxes( target["boxes"] = tv_tensors.BoundingBoxes(
[ [
[int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
for bndbox in batched_instances["bndbox"] for bndbox in batched_instances["bndbox"]
], ],
format=datapoints.BoundingBoxFormat.XYXY, format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(image.height, image.width), canvas_size=(image.height, image.width),
) )
...@@ -494,12 +494,12 @@ def celeba_wrapper_factory(dataset, target_keys): ...@@ -494,12 +494,12 @@ def celeba_wrapper_factory(dataset, target_keys):
target_types=dataset.target_type, target_types=dataset.target_type,
type_wrappers={ type_wrappers={
"bbox": lambda item: F.convert_bounding_box_format( "bbox": lambda item: F.convert_bounding_box_format(
datapoints.BoundingBoxes( tv_tensors.BoundingBoxes(
item, item,
format=datapoints.BoundingBoxFormat.XYWH, format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=(image.height, image.width), canvas_size=(image.height, image.width),
), ),
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=tv_tensors.BoundingBoxFormat.XYXY,
), ),
}, },
) )
...@@ -544,9 +544,9 @@ def kitti_wrapper_factory(dataset, target_keys): ...@@ -544,9 +544,9 @@ def kitti_wrapper_factory(dataset, target_keys):
target = {} target = {}
if "boxes" in target_keys: if "boxes" in target_keys:
target["boxes"] = datapoints.BoundingBoxes( target["boxes"] = tv_tensors.BoundingBoxes(
batched_target["bbox"], batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYXY, format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(image.height, image.width), canvas_size=(image.height, image.width),
) )
...@@ -596,7 +596,7 @@ def cityscapes_wrapper_factory(dataset, target_keys): ...@@ -596,7 +596,7 @@ def cityscapes_wrapper_factory(dataset, target_keys):
if label >= 1_000: if label >= 1_000:
label //= 1_000 label //= 1_000
labels.append(label) labels.append(label)
return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels)) return dict(masks=tv_tensors.Mask(torch.stack(masks)), labels=torch.stack(labels))
def wrapper(idx, sample): def wrapper(idx, sample):
image, target = sample image, target = sample
...@@ -641,10 +641,10 @@ def widerface_wrapper(dataset, target_keys): ...@@ -641,10 +641,10 @@ def widerface_wrapper(dataset, target_keys):
if "bbox" in target_keys: if "bbox" in target_keys:
target["bbox"] = F.convert_bounding_box_format( target["bbox"] = F.convert_bounding_box_format(
datapoints.BoundingBoxes( tv_tensors.BoundingBoxes(
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width) target["bbox"], format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
), ),
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=tv_tensors.BoundingBoxFormat.XYXY,
) )
return image, target return image, target
......