Unverified Commit d5f4cc38 authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
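
For context, a minimal sketch of what this rename looks like from the user side (assuming torchvision's v2 API around this release, where the beta datapoints namespace becomes tv_tensors; the sample tensors and constructor arguments below are illustrative, not taken from this diff):

    import torch

    # Before this commit the tensor wrapper classes lived under torchvision.datapoints:
    #     from torchvision import datapoints
    #     img = datapoints.Image(torch.rand(3, 256, 256))
    # After this commit the same classes are exposed as torchvision.tv_tensors:
    from torchvision import tv_tensors

    img = tv_tensors.Image(torch.rand(3, 256, 256))
    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[0, 0, 10, 10]]),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(256, 256),
    )
    print(type(img).__name__, type(boxes).__name__)  # Image BoundingBoxes

Transforms in torchvision.transforms.v2 accept these wrappers interchangeably with plain tensors and PIL images; this commit is intended as a pure rename (datapoints -> tv_tensors, Datapoint -> TVTensor), with behavior unchanged.
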
......@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch
from torch.utils._pytree import tree_map
from torchvision.datapoints._datapoint import Datapoint
from torchvision.tv_tensors._tv_tensor import TVTensor
L = TypeVar("L", bound="_LabelBase")
class _LabelBase(Datapoint):
class _LabelBase(TVTensor):
categories: Optional[Sequence[str]]
@classmethod
......
......@@ -7,7 +7,7 @@ import PIL.Image
import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F
from ._transform import _RandomApplyTransform, Transform
......@@ -91,10 +91,10 @@ class RandomErasing(_RandomApplyTransform):
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(functional, inpt, *args, **kwargs)
......@@ -158,7 +158,7 @@ class _BaseMixUpCutMix(Transform):
flat_inputs, spec = tree_flatten(inputs)
needs_transform_list = self._needs_transform_list(flat_inputs)
if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask):
if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")
labels = self._labels_getter(inputs)
......@@ -188,7 +188,7 @@ class _BaseMixUpCutMix(Transform):
return tree_unflatten(flat_outputs, spec)
def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
expected_num_dims = 5 if isinstance(inpt, datapoints.Video) else 4
expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
if inpt.ndim != expected_num_dims:
raise ValueError(
f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
......@@ -242,13 +242,13 @@ class MixUp(_BaseMixUpCutMix):
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=lam)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = datapoints.wrap(output, like=inpt)
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
else:
......@@ -309,7 +309,7 @@ class CutMix(_BaseMixUpCutMix):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=params["lam_adjusted"])
elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
x1, y1, x2, y2 = params["box"]
......@@ -317,8 +317,8 @@ class CutMix(_BaseMixUpCutMix):
output = inpt.clone()
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = datapoints.wrap(output, like=inpt)
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
else:
......
......@@ -5,7 +5,7 @@ import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import datapoints, transforms as _transforms
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
......@@ -15,7 +15,7 @@ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]
class _AutoAugmentBase(Transform):
......@@ -46,7 +46,7 @@ class _AutoAugmentBase(Transform):
def _flatten_and_extract_image_or_video(
self,
inputs: Any,
unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
needs_transform_list = self._needs_transform_list(flat_inputs)
......@@ -56,10 +56,10 @@ class _AutoAugmentBase(Transform):
if needs_transform and check_type(
inpt,
(
datapoints.Image,
tv_tensors.Image,
PIL.Image.Image,
is_pure_tensor,
datapoints.Video,
tv_tensors.Video,
),
):
image_or_videos.append((idx, inpt))
......@@ -590,7 +590,7 @@ class AugMix(_AutoAugmentBase):
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image_or_video.shape)
expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
......@@ -627,8 +627,8 @@ class AugMix(_AutoAugmentBase):
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
mix = datapoints.wrap(mix, like=orig_image_or_video)
if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
mix = tv_tensors.wrap(mix, like=orig_image_or_video)
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_pil_image(mix)
......
......@@ -6,7 +6,7 @@ from typing import Any, Callable, cast, Dict, List, Literal, Optional, Sequence,
import PIL.Image
import torch
from torchvision import datapoints, transforms as _transforms
from torchvision import transforms as _transforms, tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
......@@ -36,8 +36,8 @@ class RandomHorizontalFlip(_RandomApplyTransform):
.. v2betastatus:: RandomHorizontalFlip transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -56,8 +56,8 @@ class RandomVerticalFlip(_RandomApplyTransform):
.. v2betastatus:: RandomVerticalFlip transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -76,8 +76,8 @@ class Resize(Transform):
.. v2betastatus:: Resize transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -171,8 +171,8 @@ class CenterCrop(Transform):
.. v2betastatus:: CenterCrop transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -199,8 +199,8 @@ class RandomResizedCrop(Transform):
.. v2betastatus:: RandomResizedCrop transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -322,8 +322,8 @@ class FiveCrop(Transform):
.. v2betastatus:: FiveCrop transform
If the input is a :class:`torch.Tensor` or a :class:`~torchvision.datapoints.Image` or a
:class:`~torchvision.datapoints.Video` it can have arbitrary number of leading batch dimensions.
If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a
:class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions.
For example, the image can have ``[..., C, H, W]`` shape.
.. Note::
......@@ -338,15 +338,15 @@ class FiveCrop(Transform):
Example:
>>> class BatchMultiCrop(transforms.Transform):
... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], int]):
... def forward(self, sample: Tuple[Tuple[Union[tv_tensors.Image, tv_tensors.Video], ...], int]):
... images_or_videos, labels = sample
... batch_size = len(images_or_videos)
... image_or_video = images_or_videos[0]
... images_or_videos = datapoints.wrap(torch.stack(images_or_videos), like=image_or_video)
... images_or_videos = tv_tensors.wrap(torch.stack(images_or_videos), like=image_or_video)
... labels = torch.full((batch_size,), label, device=images_or_videos.device)
... return images_or_videos, labels
...
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> image = tv_tensors.Image(torch.rand(3, 256, 256))
>>> label = 3
>>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()])
>>> images, labels = transform(image, label)
......@@ -363,10 +363,10 @@ class FiveCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(functional, inpt, *args, **kwargs)
......@@ -374,7 +374,7 @@ class FiveCrop(Transform):
return self._call_kernel(F.five_crop, inpt, self.size)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
......@@ -384,8 +384,8 @@ class TenCrop(Transform):
.. v2betastatus:: TenCrop transform
If the input is a :class:`torch.Tensor` or a :class:`~torchvision.datapoints.Image` or a
:class:`~torchvision.datapoints.Video` it can have arbitrary number of leading batch dimensions.
If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a
:class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions.
For example, the image can have ``[..., C, H, W]`` shape.
See :class:`~torchvision.transforms.v2.FiveCrop` for an example.
......@@ -410,15 +410,15 @@ class TenCrop(Transform):
self.vertical_flip = vertical_flip
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(functional, inpt, *args, **kwargs)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
......@@ -430,8 +430,8 @@ class Pad(Transform):
.. v2betastatus:: Pad transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -447,7 +447,7 @@ class Pad(Transform):
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is "constant".
......@@ -515,8 +515,8 @@ class RandomZoomOut(_RandomApplyTransform):
output_width = input_width * r
output_height = input_height * r
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -524,7 +524,7 @@ class RandomZoomOut(_RandomApplyTransform):
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
side_range (sequence of floats, optional): tuple of two floats defines minimum and maximum factors to
scale the input size.
......@@ -574,8 +574,8 @@ class RandomRotation(Transform):
.. v2betastatus:: RandomRotation transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -596,7 +596,7 @@ class RandomRotation(Transform):
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
......@@ -648,8 +648,8 @@ class RandomAffine(Transform):
.. v2betastatus:: RandomAffine transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -676,7 +676,7 @@ class RandomAffine(Transform):
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image.
......@@ -770,8 +770,8 @@ class RandomCrop(Transform):
.. v2betastatus:: RandomCrop transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -794,7 +794,7 @@ class RandomCrop(Transform):
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
......@@ -927,8 +927,8 @@ class RandomPerspective(_RandomApplyTransform):
.. v2betastatus:: RandomPerspective transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -943,7 +943,7 @@ class RandomPerspective(_RandomApplyTransform):
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
"""
......@@ -1014,8 +1014,8 @@ class ElasticTransform(Transform):
.. v2betastatus:: RandomPerspective transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -1046,7 +1046,7 @@ class ElasticTransform(Transform):
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
"""
......@@ -1107,15 +1107,15 @@ class RandomIoUCrop(Transform):
.. v2betastatus:: RandomIoUCrop transform
This transformation requires an image or video data and ``datapoints.BoundingBoxes`` in the input.
This transformation requires an image or video data and ``tv_tensors.BoundingBoxes`` in the input.
.. warning::
In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
after or later in the transforms pipeline.
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -1152,8 +1152,8 @@ class RandomIoUCrop(Transform):
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_all(flat_inputs, datapoints.BoundingBoxes)
and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_pure_tensor)
has_all(flat_inputs, tv_tensors.BoundingBoxes)
and has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
......@@ -1193,7 +1193,7 @@ class RandomIoUCrop(Transform):
xyxy_bboxes = F.convert_bounding_box_format(
bboxes.as_subclass(torch.Tensor),
bboxes.format,
datapoints.BoundingBoxFormat.XYXY,
tv_tensors.BoundingBoxFormat.XYXY,
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
......@@ -1221,7 +1221,7 @@ class RandomIoUCrop(Transform):
F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
)
if isinstance(output, datapoints.BoundingBoxes):
if isinstance(output, tv_tensors.BoundingBoxes):
# We "mark" the invalid boxes as degenerate, and they can be
# removed by a later call to SanitizeBoundingBoxes()
output[~params["is_within_crop_area"]] = 0
......@@ -1235,8 +1235,8 @@ class ScaleJitter(Transform):
.. v2betastatus:: ScaleJitter transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -1303,8 +1303,8 @@ class RandomShortestSize(Transform):
.. v2betastatus:: RandomShortestSize transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......@@ -1384,8 +1384,8 @@ class RandomResize(Transform):
output_width = size
output_height = size
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
......
from typing import Any, Dict, Union
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
......@@ -10,20 +10,20 @@ class ConvertBoundingBoxFormat(Transform):
.. v2betastatus:: ConvertBoundingBoxFormat transform
Args:
format (str or datapoints.BoundingBoxFormat): output bounding box format.
Possible values are defined by :class:`~torchvision.datapoints.BoundingBoxFormat` and
format (str or tv_tensors.BoundingBoxFormat): output bounding box format.
Possible values are defined by :class:`~torchvision.tv_tensors.BoundingBoxFormat` and
string values match the enums, e.g. "XYXY" or "XYWH" etc.
"""
_transformed_types = (datapoints.BoundingBoxes,)
_transformed_types = (tv_tensors.BoundingBoxes,)
def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None:
super().__init__()
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]
format = tv_tensors.BoundingBoxFormat[format]
self.format = format
def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value]
......@@ -36,7 +36,7 @@ class ClampBoundingBoxes(Transform):
"""
_transformed_types = (datapoints.BoundingBoxes,)
_transformed_types = (tv_tensors.BoundingBoxes,)
def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
return F.clamp_bounding_boxes(inpt) # type: ignore[return-value]
......@@ -6,7 +6,7 @@ import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
......@@ -74,7 +74,7 @@ class LinearTransformation(Transform):
_v1_transform_cls = _transforms.LinearTransformation
_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
......@@ -129,8 +129,8 @@ class LinearTransformation(Transform):
output = torch.mm(flat_inpt, transformation_matrix)
output = output.reshape(shape)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = datapoints.wrap(output, like=inpt)
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
......@@ -227,12 +227,12 @@ class ToDtype(Transform):
``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``.
Args:
dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to.
dtype (``torch.dtype`` or dict of ``TVTensor`` -> ``torch.dtype``): The dtype to convert to.
If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted
to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`.
A dict can be passed to specify per-datapoint conversions, e.g.
``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others"
key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion.
A dict can be passed to specify per-tv_tensor conversions, e.g.
``dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others":None}``. The "others"
key can be used as a catch-all for any other tv_tensor type, and ``None`` means no conversion.
scale (bool, optional): Whether to scale the values for images or videos. See :ref:`range_and_dtype`.
Default: ``False``.
"""
......@@ -250,12 +250,12 @@ class ToDtype(Transform):
if (
isinstance(dtype, dict)
and torch.Tensor in dtype
and any(cls in dtype for cls in [datapoints.Image, datapoints.Video])
and any(cls in dtype for cls in [tv_tensors.Image, tv_tensors.Video])
):
warnings.warn(
"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Got `dtype` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
)
self.dtype = dtype
self.scale = scale
......@@ -264,7 +264,7 @@ class ToDtype(Transform):
if isinstance(self.dtype, torch.dtype):
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# is a simple torch.dtype
if not is_pure_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
return inpt
dtype: Optional[torch.dtype] = self.dtype
......@@ -278,10 +278,10 @@ class ToDtype(Transform):
"If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. "
"If you're passing a dict as dtype, "
'you can use "others" as a catch-all key '
'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
)
supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
if dtype is None:
if self.scale and supports_scaling:
warnings.warn(
......@@ -389,10 +389,10 @@ class SanitizeBoundingBoxes(Transform):
)
boxes = cast(
datapoints.BoundingBoxes,
tv_tensors.BoundingBoxes,
F.convert_bounding_box_format(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
),
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
......@@ -415,7 +415,7 @@ class SanitizeBoundingBoxes(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
is_label = inpt is not None and inpt is params["labels"]
is_bounding_boxes_or_mask = isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask))
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
if not (is_label or is_bounding_boxes_or_mask):
return inpt
......@@ -425,4 +425,4 @@ class SanitizeBoundingBoxes(Transform):
if is_label:
return output
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
......@@ -7,7 +7,7 @@ import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.utils import _log_api_usage_once
......@@ -56,8 +56,8 @@ class Transform(nn.Module):
def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
# Below is a heuristic on how to deal with pure tensor inputs:
# 1. Pure tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
# 1. Pure tensors, i.e. tensors that are not a tv_tensor, are passed through if there is an explicit image
# (`tv_tensors.Image` or `PIL.Image.Image`) or video (`tv_tensors.Video`) in the sample.
# 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
# transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
# of `tree_flatten`, which recurses depth-first through the input.
......@@ -72,7 +72,7 @@ class Transform(nn.Module):
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
needs_transform_list = []
transform_pure_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
for inpt in flat_inputs:
needs_transform = True
......
......@@ -4,7 +4,7 @@ import numpy as np
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import is_pure_tensor
......@@ -27,7 +27,7 @@ class PILToTensor(Transform):
class ToImage(Transform):
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.tv_tensors.Image`
; this does not scale values.
.. v2betastatus:: ToImage transform
......@@ -39,7 +39,7 @@ class ToImage(Transform):
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image:
) -> tv_tensors.Image:
return F.to_image(inpt)
......@@ -66,7 +66,7 @@ class ToPILImage(Transform):
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
"""
_transformed_types = (is_pure_tensor, datapoints.Image, np.ndarray)
_transformed_types = (is_pure_tensor, tv_tensors.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
......@@ -79,14 +79,14 @@ class ToPILImage(Transform):
class ToPureTensor(Transform):
"""[BETA] Convert all datapoints to pure tensors, removing associated metadata (if any).
"""[BETA] Convert all tv_tensors to pure tensors, removing associated metadata (if any).
.. v2betastatus:: ToPureTensor transform
This doesn't scale or change the values, only the type.
"""
_transformed_types = (datapoints.Datapoint,)
_transformed_types = (tv_tensors.TVTensor,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
return inpt.as_subclass(torch.Tensor)
......@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision._utils import sequence_to_str
......@@ -149,10 +149,10 @@ def _parse_labels_getter(
raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
# This assumes there is only one bbox per sample as per the general convention
try:
return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes))
except StopIteration:
raise ValueError("No bounding boxes were found in the sample")
......@@ -161,7 +161,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
}
if not chws:
raise TypeError("No image or video was found in the sample")
......@@ -179,11 +179,11 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
inpt,
(
is_pure_tensor,
datapoints.Image,
tv_tensors.Image,
PIL.Image.Image,
datapoints.Video,
datapoints.Mask,
datapoints.BoundingBoxes,
tv_tensors.Video,
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
),
)
}
......
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
......@@ -28,7 +28,7 @@ def erase(
@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, datapoints.Image)
@_register_kernel_internal(erase, tv_tensors.Image)
def erase_image(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
......@@ -48,7 +48,7 @@ def _erase_image_pil(
return to_pil_image(output, mode=image.mode)
@_register_kernel_internal(erase, datapoints.Video)
@_register_kernel_internal(erase, tv_tensors.Video)
def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
......
......@@ -3,7 +3,7 @@ from typing import List
import PIL.Image
import torch
from torch.nn.functional import conv2d
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _max_value
......@@ -47,7 +47,7 @@ def _rgb_to_grayscale_image(
@_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
@_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
@_register_kernel_internal(rgb_to_grayscale, tv_tensors.Image)
def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
......@@ -82,7 +82,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten
@_register_kernel_internal(adjust_brightness, torch.Tensor)
@_register_kernel_internal(adjust_brightness, datapoints.Image)
@_register_kernel_internal(adjust_brightness, tv_tensors.Image)
def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
......@@ -102,7 +102,7 @@ def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: floa
return _FP.adjust_brightness(image, brightness_factor=brightness_factor)
@_register_kernel_internal(adjust_brightness, datapoints.Video)
@_register_kernel_internal(adjust_brightness, tv_tensors.Video)
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
return adjust_brightness_image(video, brightness_factor=brightness_factor)
......@@ -119,7 +119,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten
@_register_kernel_internal(adjust_saturation, torch.Tensor)
@_register_kernel_internal(adjust_saturation, datapoints.Image)
@_register_kernel_internal(adjust_saturation, tv_tensors.Image)
def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
......@@ -141,7 +141,7 @@ def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> to
_adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)
@_register_kernel_internal(adjust_saturation, datapoints.Video)
@_register_kernel_internal(adjust_saturation, tv_tensors.Video)
def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
return adjust_saturation_image(video, saturation_factor=saturation_factor)
......@@ -158,7 +158,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
@_register_kernel_internal(adjust_contrast, torch.Tensor)
@_register_kernel_internal(adjust_contrast, datapoints.Image)
@_register_kernel_internal(adjust_contrast, tv_tensors.Image)
def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
......@@ -180,7 +180,7 @@ def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.
_adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)
@_register_kernel_internal(adjust_contrast, datapoints.Video)
@_register_kernel_internal(adjust_contrast, tv_tensors.Video)
def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
return adjust_contrast_image(video, contrast_factor=contrast_factor)
......@@ -197,7 +197,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso
@_register_kernel_internal(adjust_sharpness, torch.Tensor)
@_register_kernel_internal(adjust_sharpness, datapoints.Image)
@_register_kernel_internal(adjust_sharpness, tv_tensors.Image)
def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = image.shape[-3:]
if num_channels not in (1, 3):
......@@ -253,7 +253,7 @@ def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torc
_adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)
@_register_kernel_internal(adjust_sharpness, datapoints.Video)
@_register_kernel_internal(adjust_sharpness, tv_tensors.Video)
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
return adjust_sharpness_image(video, sharpness_factor=sharpness_factor)
......@@ -340,7 +340,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(adjust_hue, torch.Tensor)
@_register_kernel_internal(adjust_hue, datapoints.Image)
@_register_kernel_internal(adjust_hue, tv_tensors.Image)
def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
......@@ -371,7 +371,7 @@ def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
_adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)
@_register_kernel_internal(adjust_hue, datapoints.Video)
@_register_kernel_internal(adjust_hue, tv_tensors.Video)
def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image(video, hue_factor=hue_factor)
......@@ -388,7 +388,7 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten
@_register_kernel_internal(adjust_gamma, torch.Tensor)
@_register_kernel_internal(adjust_gamma, datapoints.Image)
@_register_kernel_internal(adjust_gamma, tv_tensors.Image)
def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
......@@ -411,7 +411,7 @@ def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) ->
_adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)
@_register_kernel_internal(adjust_gamma, datapoints.Video)
@_register_kernel_internal(adjust_gamma, tv_tensors.Video)
def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
return adjust_gamma_image(video, gamma=gamma, gain=gain)
......@@ -428,7 +428,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
@_register_kernel_internal(posterize, torch.Tensor)
@_register_kernel_internal(posterize, datapoints.Image)
@_register_kernel_internal(posterize, tv_tensors.Image)
def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
if image.is_floating_point():
levels = 1 << bits
......@@ -445,7 +445,7 @@ def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
_posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)
@_register_kernel_internal(posterize, datapoints.Video)
@_register_kernel_internal(posterize, tv_tensors.Video)
def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image(video, bits=bits)
......@@ -462,7 +462,7 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
@_register_kernel_internal(solarize, torch.Tensor)
@_register_kernel_internal(solarize, datapoints.Image)
@_register_kernel_internal(solarize, tv_tensors.Image)
def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
if threshold > _max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
......@@ -473,7 +473,7 @@ def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
_solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)
@_register_kernel_internal(solarize, datapoints.Video)
@_register_kernel_internal(solarize, tv_tensors.Video)
def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image(video, threshold=threshold)
......@@ -490,7 +490,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(autocontrast, torch.Tensor)
@_register_kernel_internal(autocontrast, datapoints.Image)
@_register_kernel_internal(autocontrast, tv_tensors.Image)
def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
c = image.shape[-3]
if c not in [1, 3]:
......@@ -523,7 +523,7 @@ def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
_autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)
@_register_kernel_internal(autocontrast, datapoints.Video)
@_register_kernel_internal(autocontrast, tv_tensors.Video)
def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image(video)
......@@ -540,7 +540,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(equalize, torch.Tensor)
@_register_kernel_internal(equalize, datapoints.Image)
@_register_kernel_internal(equalize, tv_tensors.Image)
def equalize_image(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
return image
......@@ -613,7 +613,7 @@ def equalize_image(image: torch.Tensor) -> torch.Tensor:
_equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)
@_register_kernel_internal(equalize, datapoints.Video)
@_register_kernel_internal(equalize, tv_tensors.Video)
def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image(video)
......@@ -630,7 +630,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(invert, torch.Tensor)
@_register_kernel_internal(invert, datapoints.Image)
@_register_kernel_internal(invert, tv_tensors.Image)
def invert_image(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point():
return 1.0 - image
......@@ -644,7 +644,7 @@ def invert_image(image: torch.Tensor) -> torch.Tensor:
_invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)
@_register_kernel_internal(invert, datapoints.Video)
@_register_kernel_internal(invert, tv_tensors.Video)
def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image(video)
......@@ -653,7 +653,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
"""Permute the channels of the input according to the given permutation.
This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
:class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.
:class:`torchvision.tv_tensors.Image` and :class:`torchvision.tv_tensors.Video`.
Example:
>>> rgb_image = torch.rand(3, 256, 256)
......@@ -681,7 +681,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
@_register_kernel_internal(permute_channels, torch.Tensor)
@_register_kernel_internal(permute_channels, datapoints.Image)
@_register_kernel_internal(permute_channels, tv_tensors.Image)
def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
shape = image.shape
num_channels, height, width = shape[-3:]
......@@ -704,6 +704,6 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int])
return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))
@_register_kernel_internal(permute_channels, datapoints.Video)
@_register_kernel_internal(permute_channels, tv_tensors.Video)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
return permute_channels_image(video, permutation=permutation)
......@@ -7,7 +7,7 @@ import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
from torchvision.transforms.functional import (
......@@ -51,7 +51,7 @@ def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, datapoints.Image)
@_register_kernel_internal(horizontal_flip, tv_tensors.Image)
def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1)
......@@ -61,37 +61,37 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.hflip(image)
@_register_kernel_internal(horizontal_flip, datapoints.Mask)
@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image(mask)
def horizontal_flip_bounding_boxes(
bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_boxes.shape
bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
if format == datapoints.BoundingBoxFormat.XYXY:
if format == tv_tensors.BoundingBoxFormat.XYXY:
bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
elif format == datapoints.BoundingBoxFormat.XYWH:
elif format == tv_tensors.BoundingBoxFormat.XYWH:
bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
else: # format == datapoints.BoundingBoxFormat.CXCYWH:
else: # format == tv_tensors.BoundingBoxFormat.CXCYWH:
bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()
return bounding_boxes.reshape(shape)
@_register_kernel_internal(horizontal_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
@_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
output = horizontal_flip_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(horizontal_flip, datapoints.Video)
@_register_kernel_internal(horizontal_flip, tv_tensors.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image(video)
......@@ -108,7 +108,7 @@ def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, datapoints.Image)
@_register_kernel_internal(vertical_flip, tv_tensors.Image)
def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
return image.flip(-2)
......@@ -118,37 +118,37 @@ def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
return _FP.vflip(image)
@_register_kernel_internal(vertical_flip, datapoints.Mask)
@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image(mask)
def vertical_flip_bounding_boxes(
bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_boxes.shape
bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
if format == datapoints.BoundingBoxFormat.XYXY:
if format == tv_tensors.BoundingBoxFormat.XYXY:
bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
elif format == datapoints.BoundingBoxFormat.XYWH:
elif format == tv_tensors.BoundingBoxFormat.XYWH:
bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
else: # format == datapoints.BoundingBoxFormat.CXCYWH:
else: # format == tv_tensors.BoundingBoxFormat.CXCYWH:
bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()
return bounding_boxes.reshape(shape)
@_register_kernel_internal(vertical_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
@_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
output = vertical_flip_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(vertical_flip, datapoints.Video)
@_register_kernel_internal(vertical_flip, tv_tensors.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
return vertical_flip_image(video)
......@@ -190,7 +190,7 @@ def resize(
@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, datapoints.Image)
@_register_kernel_internal(resize, tv_tensors.Image)
def resize_image(
image: torch.Tensor,
size: List[int],
......@@ -319,12 +319,12 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
return output
@_register_kernel_internal(resize, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resize_mask_dispatch(
inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.Mask:
inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.Mask:
output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
def resize_bounding_boxes(
......@@ -345,17 +345,17 @@ def resize_bounding_boxes(
)
@_register_kernel_internal(resize, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resize_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.BoundingBoxes:
inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.BoundingBoxes:
output, canvas_size = resize_bounding_boxes(
inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
)
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(resize, datapoints.Video)
@_register_kernel_internal(resize, tv_tensors.Video)
def resize_video(
video: torch.Tensor,
size: List[int],
......@@ -651,7 +651,7 @@ def _affine_grid(
@_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, datapoints.Image)
@_register_kernel_internal(affine, tv_tensors.Image)
def affine_image(
image: torch.Tensor,
angle: Union[int, float],
......@@ -730,7 +730,7 @@ def _affine_image_pil(
def _affine_bounding_boxes_with_expand(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
angle: Union[int, float],
translate: List[float],
......@@ -749,7 +749,7 @@ def _affine_bounding_boxes_with_expand(
device = bounding_boxes.device
bounding_boxes = (
convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
)
).reshape(-1, 4)
......@@ -808,9 +808,9 @@ def _affine_bounding_boxes_with_expand(
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
canvas_size = (new_height, new_width)
out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
out_bboxes = convert_bounding_box_format(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape)
out_bboxes = out_bboxes.to(original_dtype)
......@@ -819,7 +819,7 @@ def _affine_bounding_boxes_with_expand(
def affine_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
angle: Union[int, float],
translate: List[float],
......@@ -841,16 +841,16 @@ def affine_bounding_boxes(
return out_box
@_register_kernel_internal(affine, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _affine_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes,
inpt: tv_tensors.BoundingBoxes,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
**kwargs,
) -> datapoints.BoundingBoxes:
) -> tv_tensors.BoundingBoxes:
output = affine_bounding_boxes(
inpt.as_subclass(torch.Tensor),
format=inpt.format,
......@@ -861,7 +861,7 @@ def _affine_bounding_boxes_dispatch(
shear=shear,
center=center,
)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
def affine_mask(
......@@ -896,9 +896,9 @@ def affine_mask(
return output
@_register_kernel_internal(affine, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False)
def _affine_mask_dispatch(
inpt: datapoints.Mask,
inpt: tv_tensors.Mask,
angle: Union[int, float],
translate: List[float],
scale: float,
......@@ -906,7 +906,7 @@ def _affine_mask_dispatch(
fill: _FillTypeJIT = None,
center: Optional[List[float]] = None,
**kwargs,
) -> datapoints.Mask:
) -> tv_tensors.Mask:
output = affine_mask(
inpt.as_subclass(torch.Tensor),
angle=angle,
......@@ -916,10 +916,10 @@ def _affine_mask_dispatch(
fill=fill,
center=center,
)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(affine, datapoints.Video)
@_register_kernel_internal(affine, tv_tensors.Video)
def affine_video(
video: torch.Tensor,
angle: Union[int, float],
......@@ -961,7 +961,7 @@ def rotate(
@_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, datapoints.Image)
@_register_kernel_internal(rotate, tv_tensors.Image)
def rotate_image(
image: torch.Tensor,
angle: float,
......@@ -1027,7 +1027,7 @@ def _rotate_image_pil(
def rotate_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
angle: float,
expand: bool = False,
......@@ -1049,10 +1049,10 @@ def rotate_bounding_boxes(
)
@_register_kernel_internal(rotate, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _rotate_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> datapoints.BoundingBoxes:
inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> tv_tensors.BoundingBoxes:
output, canvas_size = rotate_bounding_boxes(
inpt.as_subclass(torch.Tensor),
format=inpt.format,
......@@ -1061,7 +1061,7 @@ def _rotate_bounding_boxes_dispatch(
expand=expand,
center=center,
)
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
def rotate_mask(
......@@ -1092,20 +1092,20 @@ def rotate_mask(
return output
@_register_kernel_internal(rotate, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False)
def _rotate_mask_dispatch(
inpt: datapoints.Mask,
inpt: tv_tensors.Mask,
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
fill: _FillTypeJIT = None,
**kwargs,
) -> datapoints.Mask:
) -> tv_tensors.Mask:
output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(rotate, datapoints.Video)
@_register_kernel_internal(rotate, tv_tensors.Video)
def rotate_video(
video: torch.Tensor,
angle: float,
......@@ -1158,7 +1158,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
@_register_kernel_internal(pad, torch.Tensor)
@_register_kernel_internal(pad, datapoints.Image)
@_register_kernel_internal(pad, tv_tensors.Image)
def pad_image(
image: torch.Tensor,
padding: List[int],
......@@ -1260,7 +1260,7 @@ def _pad_with_vector_fill(
_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
@_register_kernel_internal(pad, datapoints.Mask)
@_register_kernel_internal(pad, tv_tensors.Mask)
def pad_mask(
mask: torch.Tensor,
padding: List[int],
......@@ -1289,7 +1289,7 @@ def pad_mask(
def pad_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
padding: List[int],
padding_mode: str = "constant",
......@@ -1300,7 +1300,7 @@ def pad_bounding_boxes(
left, right, top, bottom = _parse_pad_padding(padding)
if format == datapoints.BoundingBoxFormat.XYXY:
if format == tv_tensors.BoundingBoxFormat.XYXY:
pad = [left, top, left, top]
else:
pad = [left, top, 0, 0]
......@@ -1314,10 +1314,10 @@ def pad_bounding_boxes(
return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
@_register_kernel_internal(pad, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _pad_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> datapoints.BoundingBoxes:
inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> tv_tensors.BoundingBoxes:
output, canvas_size = pad_bounding_boxes(
inpt.as_subclass(torch.Tensor),
format=inpt.format,
......@@ -1325,10 +1325,10 @@ def _pad_bounding_boxes_dispatch(
padding=padding,
padding_mode=padding_mode,
)
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(pad, datapoints.Video)
@_register_kernel_internal(pad, tv_tensors.Video)
def pad_video(
video: torch.Tensor,
padding: List[int],
......@@ -1350,7 +1350,7 @@ def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> to
@_register_kernel_internal(crop, torch.Tensor)
@_register_kernel_internal(crop, datapoints.Image)
@_register_kernel_internal(crop, tv_tensors.Image)
def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
h, w = image.shape[-2:]
......@@ -1375,7 +1375,7 @@ _register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)
def crop_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
top: int,
left: int,
height: int,
......@@ -1383,7 +1383,7 @@ def crop_bounding_boxes(
) -> Tuple[torch.Tensor, Tuple[int, int]]:
# Crop or implicit pad if left and/or top have negative values:
if format == datapoints.BoundingBoxFormat.XYXY:
if format == tv_tensors.BoundingBoxFormat.XYXY:
sub = [left, top, left, top]
else:
sub = [left, top, 0, 0]
......@@ -1394,17 +1394,17 @@ def crop_bounding_boxes(
return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
@_register_kernel_internal(crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int
) -> datapoints.BoundingBoxes:
inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int
) -> tv_tensors.BoundingBoxes:
output, canvas_size = crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
)
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(crop, datapoints.Mask)
@_register_kernel_internal(crop, tv_tensors.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
......@@ -1420,7 +1420,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
return output
@_register_kernel_internal(crop, datapoints.Video)
@_register_kernel_internal(crop, tv_tensors.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image(video, top, left, height, width)
......@@ -1505,7 +1505,7 @@ def _perspective_coefficients(
@_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, datapoints.Image)
@_register_kernel_internal(perspective, tv_tensors.Image)
def perspective_image(
image: torch.Tensor,
startpoints: Optional[List[List[int]]],
......@@ -1568,7 +1568,7 @@ def _perspective_image_pil(
def perspective_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
......@@ -1582,7 +1582,7 @@ def perspective_bounding_boxes(
original_shape = bounding_boxes.shape
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format
bounding_boxes = (
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
).reshape(-1, 4)
dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
......@@ -1649,25 +1649,25 @@ def perspective_bounding_boxes(
out_bboxes = clamp_bounding_boxes(
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=canvas_size,
)
# out_bboxes should be of shape [N boxes, 4]
return convert_bounding_box_format(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape)
@_register_kernel_internal(perspective, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _perspective_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes,
inpt: tv_tensors.BoundingBoxes,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
coefficients: Optional[List[float]] = None,
**kwargs,
) -> datapoints.BoundingBoxes:
) -> tv_tensors.BoundingBoxes:
output = perspective_bounding_boxes(
inpt.as_subclass(torch.Tensor),
format=inpt.format,
......@@ -1676,7 +1676,7 @@ def _perspective_bounding_boxes_dispatch(
endpoints=endpoints,
coefficients=coefficients,
)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
def perspective_mask(
......@@ -1702,15 +1702,15 @@ def perspective_mask(
return output
@_register_kernel_internal(perspective, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False)
def _perspective_mask_dispatch(
inpt: datapoints.Mask,
inpt: tv_tensors.Mask,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
**kwargs,
) -> datapoints.Mask:
) -> tv_tensors.Mask:
output = perspective_mask(
inpt.as_subclass(torch.Tensor),
startpoints=startpoints,
......@@ -1718,10 +1718,10 @@ def _perspective_mask_dispatch(
fill=fill,
coefficients=coefficients,
)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(perspective, datapoints.Video)
@_register_kernel_internal(perspective, tv_tensors.Video)
def perspective_video(
video: torch.Tensor,
startpoints: Optional[List[List[int]]],
......@@ -1755,7 +1755,7 @@ elastic_transform = elastic
@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, datapoints.Image)
@_register_kernel_internal(elastic, tv_tensors.Image)
def elastic_image(
image: torch.Tensor,
displacement: torch.Tensor,
......@@ -1841,7 +1841,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to
def elastic_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
displacement: torch.Tensor,
) -> torch.Tensor:
......@@ -1864,7 +1864,7 @@ def elastic_bounding_boxes(
original_shape = bounding_boxes.shape
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format
bounding_boxes = (
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
).reshape(-1, 4)
id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
......@@ -1887,23 +1887,23 @@ def elastic_bounding_boxes(
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
out_bboxes = clamp_bounding_boxes(
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=canvas_size,
)
return convert_bounding_box_format(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape)
@_register_kernel_internal(elastic, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _elastic_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> datapoints.BoundingBoxes:
inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> tv_tensors.BoundingBoxes:
output = elastic_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
def elastic_mask(
......@@ -1925,15 +1925,15 @@ def elastic_mask(
return output
@_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
def _elastic_mask_dispatch(
inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> datapoints.Mask:
inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> tv_tensors.Mask:
output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(elastic, datapoints.Video)
@_register_kernel_internal(elastic, tv_tensors.Video)
def elastic_video(
video: torch.Tensor,
displacement: torch.Tensor,
......@@ -1982,7 +1982,7 @@ def _center_crop_compute_crop_anchor(
@_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, datapoints.Image)
@_register_kernel_internal(center_crop, tv_tensors.Image)
def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
shape = image.shape
......@@ -2021,7 +2021,7 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PI
def center_crop_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
canvas_size: Tuple[int, int],
output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
......@@ -2032,17 +2032,17 @@ def center_crop_bounding_boxes(
)
@_register_kernel_internal(center_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _center_crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, output_size: List[int]
) -> datapoints.BoundingBoxes:
inpt: tv_tensors.BoundingBoxes, output_size: List[int]
) -> tv_tensors.BoundingBoxes:
output, canvas_size = center_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
)
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
@_register_kernel_internal(center_crop, datapoints.Mask)
@_register_kernel_internal(center_crop, tv_tensors.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
......@@ -2058,7 +2058,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
return output
@_register_kernel_internal(center_crop, datapoints.Video)
@_register_kernel_internal(center_crop, tv_tensors.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
return center_crop_image(video, output_size)
......@@ -2102,7 +2102,7 @@ def resized_crop(
@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, datapoints.Image)
@_register_kernel_internal(resized_crop, tv_tensors.Image)
def resized_crop_image(
image: torch.Tensor,
top: int,
......@@ -2156,7 +2156,7 @@ def _resized_crop_image_pil_dispatch(
def resized_crop_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
format: tv_tensors.BoundingBoxFormat,
top: int,
left: int,
height: int,
......@@ -2167,14 +2167,14 @@ def resized_crop_bounding_boxes(
return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)
@_register_kernel_internal(resized_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> datapoints.BoundingBoxes:
inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.BoundingBoxes:
output, canvas_size = resized_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
)
return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
def resized_crop_mask(
......@@ -2189,17 +2189,17 @@ def resized_crop_mask(
return resize_mask(mask, size)
@_register_kernel_internal(resized_crop, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resized_crop_mask_dispatch(
inpt: datapoints.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> datapoints.Mask:
inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.Mask:
output = resized_crop_mask(
inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
@_register_kernel_internal(resized_crop, datapoints.Video)
@_register_kernel_internal(resized_crop, tv_tensors.Video)
def resized_crop_video(
video: torch.Tensor,
top: int,
......@@ -2243,7 +2243,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image)
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
def five_crop_image(
image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
......@@ -2281,7 +2281,7 @@ def _five_crop_image_pil(
return tl, tr, bl, br, center
@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video)
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
def five_crop_video(
video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
......@@ -2313,7 +2313,7 @@ def ten_crop(
@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image)
@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image)
def ten_crop_image(
image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
......@@ -2367,7 +2367,7 @@ def _ten_crop_image_pil(
return non_flipped + flipped
@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video)
@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video)
def ten_crop_video(
video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
......
......@@ -2,9 +2,9 @@ from typing import List, Optional, Tuple
import PIL.Image
import torch
from torchvision import datapoints
from torchvision.datapoints import BoundingBoxFormat
from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.tv_tensors import BoundingBoxFormat
from torchvision.utils import _log_api_usage_once
......@@ -22,7 +22,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_dimensions, torch.Tensor)
@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
@_register_kernel_internal(get_dimensions, tv_tensors.Image, tv_tensor_wrapper=False)
def get_dimensions_image(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:])
ndims = len(chw)
......@@ -38,7 +38,7 @@ def get_dimensions_image(image: torch.Tensor) -> List[int]:
_get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
@_register_kernel_internal(get_dimensions, tv_tensors.Video, tv_tensor_wrapper=False)
def get_dimensions_video(video: torch.Tensor) -> List[int]:
return get_dimensions_image(video)
......@@ -54,7 +54,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
@_register_kernel_internal(get_num_channels, torch.Tensor)
@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
@_register_kernel_internal(get_num_channels, tv_tensors.Image, tv_tensor_wrapper=False)
def get_num_channels_image(image: torch.Tensor) -> int:
chw = image.shape[-3:]
ndims = len(chw)
......@@ -69,7 +69,7 @@ def get_num_channels_image(image: torch.Tensor) -> int:
_get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
@_register_kernel_internal(get_num_channels, tv_tensors.Video, tv_tensor_wrapper=False)
def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image(video)
......@@ -90,7 +90,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_size, torch.Tensor)
@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
@_register_kernel_internal(get_size, tv_tensors.Image, tv_tensor_wrapper=False)
def get_size_image(image: torch.Tensor) -> List[int]:
hw = list(image.shape[-2:])
ndims = len(hw)
......@@ -106,18 +106,18 @@ def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
return [height, width]
@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
def get_size_video(video: torch.Tensor) -> List[int]:
return get_size_image(video)
@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(get_size, tv_tensors.Mask, tv_tensor_wrapper=False)
def get_size_mask(mask: torch.Tensor) -> List[int]:
return get_size_image(mask)
@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
@_register_kernel_internal(get_size, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> List[int]:
return list(bounding_box.canvas_size)
......@@ -132,7 +132,7 @@ def get_num_frames(inpt: torch.Tensor) -> int:
@_register_kernel_internal(get_num_frames, torch.Tensor)
@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
@_register_kernel_internal(get_num_frames, tv_tensors.Video, tv_tensor_wrapper=False)
def get_num_frames_video(video: torch.Tensor) -> int:
return video.shape[-4]
......@@ -205,7 +205,7 @@ def convert_bounding_box_format(
) -> torch.Tensor:
"""[BETA] See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details."""
# This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# inputs as well as extract it from `tv_tensors.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
if new_format is None:
......@@ -218,16 +218,16 @@ def convert_bounding_box_format(
if old_format is None:
raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
elif isinstance(inpt, datapoints.BoundingBoxes):
elif isinstance(inpt, tv_tensors.BoundingBoxes):
if old_format is not None:
raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
raise ValueError("For bounding box tv_tensor inputs, `old_format` must not be passed.")
output = _convert_bounding_box_format(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
)
return datapoints.wrap(output, like=inpt, format=new_format)
return tv_tensors.wrap(output, like=inpt, format=new_format)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
)
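The comment above describes the kernel/functional hybrid behaviour: old_format is required for pure tensors and must be omitted for BoundingBoxes inputs, which already carry their format. A short sketch of both call styles, assuming the public v2 functional API:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import convert_bounding_box_format

xywh = torch.tensor([[2, 3, 4, 5]])  # x, y, w, h

# Pure tensor: both formats must be spelled out.
xyxy = convert_bounding_box_format(
    xywh,
    old_format=tv_tensors.BoundingBoxFormat.XYWH,
    new_format=tv_tensors.BoundingBoxFormat.XYXY,
)  # tensor([[2, 3, 6, 8]])

# BoundingBoxes input: old_format is read from the tv_tensor and must not be passed.
boxes = tv_tensors.BoundingBoxes(xywh, format="XYWH", canvas_size=(20, 20))
boxes_xyxy = convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY)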
......@@ -239,7 +239,7 @@ def _clamp_bounding_boxes(
in_dtype = bounding_boxes.dtype
bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
xyxy_boxes = convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
)
xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
......@@ -263,12 +263,12 @@ def clamp_bounding_boxes(
if format is None or canvas_size is None:
raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
elif isinstance(inpt, datapoints.BoundingBoxes):
elif isinstance(inpt, tv_tensors.BoundingBoxes):
if format is not None or canvas_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
)
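clamp_bounding_boxes follows the same hybrid pattern: format and canvas_size are required for pure tensors and read from BoundingBoxes inputs. A small sketch with hypothetical values, under the same API assumption:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import clamp_bounding_boxes

# A box sticking out of a 10x10 canvas gets clipped to the canvas.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[-2.0, 4.0, 15.0, 12.0]]), format="XYXY", canvas_size=(10, 10)
)
clamped = clamp_bounding_boxes(boxes)  # values become [[0., 4., 10., 10.]]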
......@@ -5,7 +5,7 @@ import PIL.Image
import torch
from torch.nn.functional import conv2d, pad as torch_pad
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
......@@ -31,7 +31,7 @@ def normalize(
@_register_kernel_internal(normalize, torch.Tensor)
@_register_kernel_internal(normalize, datapoints.Image)
@_register_kernel_internal(normalize, tv_tensors.Image)
def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
if not image.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
......@@ -65,7 +65,7 @@ def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], in
return image.div_(std)
@_register_kernel_internal(normalize, datapoints.Video)
@_register_kernel_internal(normalize, tv_tensors.Video)
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
return normalize_image(video, mean, std, inplace=inplace)
......@@ -98,7 +98,7 @@ def _get_gaussian_kernel2d(
@_register_kernel_internal(gaussian_blur, torch.Tensor)
@_register_kernel_internal(gaussian_blur, datapoints.Image)
@_register_kernel_internal(gaussian_blur, tv_tensors.Image)
def gaussian_blur_image(
image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
......@@ -172,7 +172,7 @@ def _gaussian_blur_image_pil(
return to_pil_image(output, mode=image.mode)
@_register_kernel_internal(gaussian_blur, datapoints.Video)
@_register_kernel_internal(gaussian_blur, tv_tensors.Video)
def gaussian_blur_video(
video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
......@@ -206,7 +206,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
@_register_kernel_internal(to_dtype, torch.Tensor)
@_register_kernel_internal(to_dtype, datapoints.Image)
@_register_kernel_internal(to_dtype, tv_tensors.Image)
def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if image.dtype == dtype:
......@@ -265,13 +265,13 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32)
return to_dtype_image(image, dtype=dtype, scale=True)
@_register_kernel_internal(to_dtype, datapoints.Video)
@_register_kernel_internal(to_dtype, tv_tensors.Video)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
return to_dtype_image(video, dtype, scale=scale)
@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(to_dtype, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
@_register_kernel_internal(to_dtype, tv_tensors.Mask, tv_tensor_wrapper=False)
def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
# We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
# We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type
return inpt.to(dtype)
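The comment relies on TVTensor.to() preserving the subclass, so BoundingBoxes and Mask only need a plain .to() here. A minimal illustration, assuming the dispatch shown above:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import to_dtype

boxes = tv_tensors.BoundingBoxes(torch.tensor([[0, 0, 5, 5]]), format="XYXY", canvas_size=(10, 10))
out = to_dtype(boxes, torch.float32)
# out.dtype is float32 and type(out) is still tv_tensors.BoundingBoxes;
# the scale flag only matters for Image/Video inputs, this dispatch ignores it.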
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.utils import _log_api_usage_once
......@@ -19,7 +19,7 @@ def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Te
@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
@_register_kernel_internal(uniform_temporal_subsample, tv_tensors.Video)
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
t_max = video.shape[-4] - 1
......
......@@ -3,12 +3,12 @@ from typing import Union
import numpy as np
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import functional as _F
@torch.jit.unused
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image:
"""[BETA] See :class:`~torchvision.transforms.v2.ToImage` for details."""
if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
......@@ -18,7 +18,7 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoin
output = inpt
else:
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
return datapoints.Image(output)
return tv_tensors.Image(output)
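to_image accepts numpy arrays (HWC), PIL images, and tensors, and returns a tv_tensors.Image in CHW layout. A quick sketch of the numpy path, based on the branch shown above:

import numpy as np
from torchvision.transforms.v2.functional import to_image

arr = np.zeros((4, 4, 3), dtype=np.uint8)  # H x W x C
img = to_image(arr)
# img is a tv_tensors.Image with shape (3, 4, 4), i.e. channels first.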
to_pil_image = _F.to_pil_image
......
......@@ -2,21 +2,21 @@ import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
import torch
from torchvision import datapoints
from torchvision import tv_tensors
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]
def is_pure_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, tv_tensors.TVTensor)
# {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
def _kernel_datapoint_wrapper(kernel):
def _kernel_tv_tensor_wrapper(kernel):
@functools.wraps(kernel)
def wrapper(inpt, *args, **kwargs):
# If you're wondering whether we could / should get rid of this wrapper,
......@@ -25,24 +25,24 @@ def _kernel_datapoint_wrapper(kernel):
# regardless of whether we override __torch_function__ in our base class
# or not.
# Also, even if we didn't call `as_subclass` here, we would still need
# this wrapper to call wrap(), because the Datapoint type would be
# this wrapper to call wrap(), because the TVTensor type would be
# lost after the first operation due to our own __torch_function__
# logic.
output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
return wrapper
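The comments above explain the purpose of this wrapper: kernels operate on plain tensors, and the TVTensor subclass would otherwise be lost after the first operation, so the wrapper unwraps the input and re-wraps the output. Roughly, for a registered Mask kernel it is equivalent to this hand-rolled sketch:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

mask = tv_tensors.Mask(torch.zeros(1, 4, 4, dtype=torch.uint8))

# What the wrapper does under the hood (for illustration only):
plain = F.horizontal_flip_mask(mask.as_subclass(torch.Tensor))  # kernel sees a plain Tensor
wrapped = tv_tensors.wrap(plain, like=mask)                     # type restored: still a Mask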
def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
def _register_kernel_internal(functional, input_type, *, tv_tensor_wrapper=True):
registry = _KERNEL_REGISTRY.setdefault(functional, {})
if input_type in registry:
raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")
def decorator(kernel):
registry[input_type] = (
_kernel_datapoint_wrapper(kernel)
if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
_kernel_tv_tensor_wrapper(kernel)
if issubclass(input_type, tv_tensors.TVTensor) and tv_tensor_wrapper
else kernel
)
return kernel
......@@ -62,14 +62,14 @@ def _name_to_functional(name):
_BUILTIN_DATAPOINT_TYPES = {
obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
obj for obj in tv_tensors.__dict__.values() if isinstance(obj, type) and issubclass(obj, tv_tensors.TVTensor)
}
def register_kernel(functional, datapoint_cls):
"""[BETA] Decorate a kernel to register it for a functional and a (custom) datapoint type.
def register_kernel(functional, tv_tensor_cls):
"""[BETA] Decorate a kernel to register it for a functional and a (custom) tv_tensor type.
See :ref:`sphx_glr_auto_examples_transforms_plot_custom_datapoints.py` for usage
See :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for usage
details.
"""
if isinstance(functional, str):
......@@ -83,16 +83,16 @@ def register_kernel(functional, datapoint_cls):
f"but got {functional}."
)
if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
if not (isinstance(tv_tensor_cls, type) and issubclass(tv_tensor_cls, tv_tensors.TVTensor)):
raise ValueError(
f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
f"but got {datapoint_cls}."
f"Kernels can only be registered for subclasses of torchvision.tv_tensors.TVTensor, "
f"but got {tv_tensor_cls}."
)
if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
if tv_tensor_cls in _BUILTIN_DATAPOINT_TYPES:
raise ValueError(f"Kernels cannot be registered for the builtin tv_tensor classes, but got {tv_tensor_cls}")
return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)
return _register_kernel_internal(functional, tv_tensor_cls, tv_tensor_wrapper=False)
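For reference, a hedged sketch of registering a kernel for a custom tv_tensor class; MyTVTensor and hflip_my_tv_tensor are hypothetical names, and since the helper passes tv_tensor_wrapper=False, the kernel receives the tv_tensor itself and is responsible for wrapping its output:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

class MyTVTensor(tv_tensors.TVTensor):  # hypothetical custom tv_tensor
    pass

@F.register_kernel(F.horizontal_flip, MyTVTensor)
def hflip_my_tv_tensor(inpt, *args, **kwargs):
    # No automatic re-wrapping here, so wrap the output ourselves.
    return tv_tensors.wrap(inpt.flip(-1), like=inpt)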
def _get_kernel(functional, input_type, *, allow_passthrough=False):
......@@ -103,10 +103,10 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False):
for cls in input_type.__mro__:
if cls in registry:
return registry[cls]
elif cls is datapoints.Datapoint:
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# allow kernels to be registered for datapoints.Datapoint anyway.
elif cls is tv_tensors.TVTensor:
# We don't want user-defined tv_tensors to dispatch to the pure Tensor kernels, so we explicitly stop the

# MRO traversal before hitting torch.Tensor. We can even stop at tv_tensors.TVTensor, since we don't
# allow kernels to be registered for tv_tensors.TVTensor anyway.
break
if allow_passthrough:
......@@ -130,12 +130,12 @@ def _register_five_ten_crop_kernel_internal(functional, input_type):
def wrapper(inpt, *args, **kwargs):
output = kernel(inpt, *args, **kwargs)
container_type = type(output)
return container_type(datapoints.wrap(o, like=inpt) for o in output)
return container_type(tv_tensors.wrap(o, like=inpt) for o in output)
return wrapper
def decorator(kernel):
registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
registry[input_type] = wrap(kernel) if issubclass(input_type, tv_tensors.TVTensor) else kernel
return kernel
return decorator
import torch
from ._bounding_box import BoundingBoxes, BoundingBoxFormat
from ._datapoint import Datapoint
from ._image import Image
from ._mask import Mask
from ._torch_function_helpers import set_return_type
from ._tv_tensor import TVTensor
from ._video import Video
def wrap(wrappee, *, like, **kwargs):
"""[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.datapoints.Datapoint` subclass as ``like``.
"""[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the ``format`` and ``canvas_size`` of
If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.
Args:
wrappee (Tensor): The tensor to convert.
like (:class:`~torchvision.datapoints.Datapoint`): The reference.
like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
``wrappee`` will be converted into the same subclass as ``like``.
kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.datapoint.BoundingBoxes`.
kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`.
Ignored otherwise.
"""
if isinstance(like, BoundingBoxes):
......
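A brief usage sketch for wrap(), assuming a BoundingBoxes reference so that format and canvas_size are carried over:

import torch
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 5.0, 5.0]]), format="XYXY", canvas_size=(10, 10)
)
shifted = boxes.as_subclass(torch.Tensor) + 1   # plain Tensor result
shifted = tv_tensors.wrap(shifted, like=boxes)  # BoundingBoxes again, same format/canvas_size
resized = tv_tensors.wrap(shifted, like=boxes, canvas_size=(12, 12))  # kwargs override the reference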
......@@ -6,7 +6,7 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union
import torch
from torch.utils._pytree import tree_flatten
from ._datapoint import Datapoint
from ._tv_tensor import TVTensor
class BoundingBoxFormat(Enum):
......@@ -24,13 +24,13 @@ class BoundingBoxFormat(Enum):
CXCYWH = "CXCYWH"
class BoundingBoxes(Datapoint):
class BoundingBoxes(TVTensor):
"""[BETA] :class:`torch.Tensor` subclass for bounding boxes.
.. note::
There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
There should be only one :class:`~torchvision.tv_tensors.BoundingBoxes`
instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
although one :class:`~torchvision.datapoints.BoundingBoxes` object can
although one :class:`~torchvision.tv_tensors.BoundingBoxes` object can
contain multiple bounding boxes.
Args:
......
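To illustrate the note above: keep all boxes of a sample in a single BoundingBoxes object rather than one object per box (a sketch, assuming the constructor described in the docstring):

import torch
from torchvision import tv_tensors

sample = {
    "img": tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)),
    "bbox": tv_tensors.BoundingBoxes(
        [[10, 10, 50, 50], [30, 40, 120, 200]],  # several boxes, one object
        format="XYXY",
        canvas_size=(256, 256),
    ),
}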
......@@ -9,7 +9,7 @@ from collections import defaultdict
import torch
from torchvision import datapoints, datasets
from torchvision import datasets, tv_tensors
from torchvision.transforms.v2 import functional as F
__all__ = ["wrap_dataset_for_transforms_v2"]
......@@ -36,26 +36,26 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
* :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.tv_tensors``.
The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the
``"image_id"``, ``"boxes"``, and ``"labels"``.
* :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
the target and wrap the data in the corresponding ``torchvision.tv_tensors``. The original keys are
preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
* :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.
coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
* :class:`~torchvision.datasets.Kitti`: Instead of returning the target as list of dicts, the wrapper returns a
dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is
in the corresponding ``torchvision.tv_tensors``. The original keys are preserved. If ``target_keys`` is
omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
* :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
:class:`~torchvision.datapoints.Mask` datapoint.
:class:`~torchvision.tv_tensors.Mask` tv_tensor.
* :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
:class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
:class:`~torchvision.tv_tensors.Mask` tv_tensor. The target for ``target_type="instance"`` is *replaced* by
a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.tv_tensors.Mask` tv_tensor) and
``"labels"``.
* :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.
coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
Image classification datasets
......@@ -66,13 +66,13 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
:class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).
segmentation mask into a :class:`~torchvision.tv_tensors.Mask` (second item).
Video classification datasets
Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
:class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
:class:`~torchvision.datapoints.Video` while leaving the other items as is.
:class:`~torchvision.tv_tensors.Video` while leaving the other items as is.
.. note::
......@@ -98,12 +98,12 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
)
# Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetTVTensorWrapper (see below) as well as the
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
# while we can still inject everything that we need.
wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})
# Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
# VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetTVTensorWrapper, type(dataset)), {})
# Since VisionDatasetTVTensorWrapper comes before ImageNet in the MRO, calling the class hits
# VisionDatasetTVTensorWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
# ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
# have the existing instance as attribute on the new object.
return wrapped_dataset_cls(dataset, target_keys)
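A usage sketch for the wrapper described above (paths are placeholders, and it assumes the helper is exposed as torchvision.datasets.wrap_dataset_for_transforms_v2, as in the v2 docs):

from torchvision import datasets
from torchvision.datasets import wrap_dataset_for_transforms_v2

coco = datasets.CocoDetection(root="path/to/images", annFile="path/to/annotations.json")
coco = wrap_dataset_for_transforms_v2(coco, target_keys=("boxes", "labels", "masks"))

isinstance(coco, datasets.CocoDetection)  # still True: the dynamic subclass keeps the original base
img, target = coco[0]
# target["boxes"] is a tv_tensors.BoundingBoxes in XYXY format,
# target["masks"] a tv_tensors.Mask, target["labels"] a plain label tensor.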
......@@ -125,7 +125,7 @@ class WrapperFactories(dict):
WRAPPER_FACTORIES = WrapperFactories()
class VisionDatasetDatapointWrapper:
class VisionDatasetTVTensorWrapper:
def __init__(self, dataset, target_keys):
dataset_cls = type(dataset)
......@@ -134,7 +134,7 @@ class VisionDatasetDatapointWrapper:
f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
f"but got a '{dataset_cls.__name__}' instead.\n"
f"For an example of how to perform the wrapping for custom datasets, see\n\n"
"https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
"https://pytorch.org/vision/main/auto_examples/plot_tv_tensors.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
)
for cls in dataset_cls.mro():
......@@ -221,7 +221,7 @@ def identity_wrapper_factory(dataset, target_keys):
def pil_image_to_mask(pil_image):
return datapoints.Mask(pil_image)
return tv_tensors.Mask(pil_image)
def parse_target_keys(target_keys, *, available, default):
......@@ -302,7 +302,7 @@ def video_classification_wrapper_factory(dataset, target_keys):
def wrapper(idx, sample):
video, audio, label = sample
video = datapoints.Video(video)
video = tv_tensors.Video(video)
return video, audio, label
......@@ -373,16 +373,16 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
if "boxes" in target_keys:
target["boxes"] = F.convert_bounding_box_format(
datapoints.BoundingBoxes(
tv_tensors.BoundingBoxes(
batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYWH,
format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=canvas_size,
),
new_format=datapoints.BoundingBoxFormat.XYXY,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
)
if "masks" in target_keys:
target["masks"] = datapoints.Mask(
target["masks"] = tv_tensors.Mask(
torch.stack(
[
segmentation_to_mask(segmentation, canvas_size=canvas_size)
......@@ -454,12 +454,12 @@ def voc_detection_wrapper_factory(dataset, target_keys):
target = {}
if "boxes" in target_keys:
target["boxes"] = datapoints.BoundingBoxes(
target["boxes"] = tv_tensors.BoundingBoxes(
[
[int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
for bndbox in batched_instances["bndbox"]
],
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(image.height, image.width),
)
......@@ -494,12 +494,12 @@ def celeba_wrapper_factory(dataset, target_keys):
target_types=dataset.target_type,
type_wrappers={
"bbox": lambda item: F.convert_bounding_box_format(
datapoints.BoundingBoxes(
tv_tensors.BoundingBoxes(
item,
format=datapoints.BoundingBoxFormat.XYWH,
format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=(image.height, image.width),
),
new_format=datapoints.BoundingBoxFormat.XYXY,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
),
},
)
......@@ -544,9 +544,9 @@ def kitti_wrapper_factory(dataset, target_keys):
target = {}
if "boxes" in target_keys:
target["boxes"] = datapoints.BoundingBoxes(
target["boxes"] = tv_tensors.BoundingBoxes(
batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(image.height, image.width),
)
......@@ -596,7 +596,7 @@ def cityscapes_wrapper_factory(dataset, target_keys):
if label >= 1_000:
label //= 1_000
labels.append(label)
return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
return dict(masks=tv_tensors.Mask(torch.stack(masks)), labels=torch.stack(labels))
def wrapper(idx, sample):
image, target = sample
......@@ -641,10 +641,10 @@ def widerface_wrapper(dataset, target_keys):
if "bbox" in target_keys:
target["bbox"] = F.convert_bounding_box_format(
datapoints.BoundingBoxes(
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
tv_tensors.BoundingBoxes(
target["bbox"], format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
),
new_format=datapoints.BoundingBoxFormat.XYXY,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
)
return image, target
......