Unverified Commit d5f4cc38 authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
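The change is a mechanical rename across the transforms v2 stack: the `Datapoint` tensor subclass becomes `TVTensor` and the `torchvision.datapoints` namespace becomes `torchvision.tv_tensors`. As a hedged sketch of the user-facing effect (the sample values below are invented and are not part of this commit), code that previously built datapoints now builds tv_tensors:

import torch
from torchvision import tv_tensors  # previously: from torchvision import datapoints

img = tv_tensors.Image(torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes(
    [[10, 10, 50, 50]], format="XYXY", canvas_size=(224, 224)
)

# The common base class was renamed from Datapoint to TVTensor.
print(isinstance(img, tv_tensors.TVTensor), isinstance(boxes, tv_tensors.TVTensor))  # True True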
@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
 import torch
 from torch.utils._pytree import tree_map
-from torchvision.datapoints._datapoint import Datapoint
+from torchvision.tv_tensors._tv_tensor import TVTensor
 L = TypeVar("L", bound="_LabelBase")
-class _LabelBase(Datapoint):
+class _LabelBase(TVTensor):
     categories: Optional[Sequence[str]]
     @classmethod
...
@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch.nn.functional import one_hot
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms.v2 import functional as F
 from ._transform import _RandomApplyTransform, Transform
@@ -91,10 +91,10 @@ class RandomErasing(_RandomApplyTransform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))
     def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
-                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
             )
         return super()._call_kernel(functional, inpt, *args, **kwargs)
@@ -158,7 +158,7 @@ class _BaseMixUpCutMix(Transform):
         flat_inputs, spec = tree_flatten(inputs)
         needs_transform_list = self._needs_transform_list(flat_inputs)
-        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask):
+        if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
             raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")
         labels = self._labels_getter(inputs)
@@ -188,7 +188,7 @@ class _BaseMixUpCutMix(Transform):
         return tree_unflatten(flat_outputs, spec)
     def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
-        expected_num_dims = 5 if isinstance(inpt, datapoints.Video) else 4
+        expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
         if inpt.ndim != expected_num_dims:
             raise ValueError(
                 f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
@@ -242,13 +242,13 @@ class MixUp(_BaseMixUpCutMix):
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=lam)
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
+        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])
             output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = datapoints.wrap(output, like=inpt)
+            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+                output = tv_tensors.wrap(output, like=inpt)
             return output
         else:
@@ -309,7 +309,7 @@ class CutMix(_BaseMixUpCutMix):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=params["lam_adjusted"])
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
+        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])
             x1, y1, x2, y2 = params["box"]
@@ -317,8 +317,8 @@ class CutMix(_BaseMixUpCutMix):
             output = inpt.clone()
             output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = datapoints.wrap(output, like=inpt)
+            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+                output = tv_tensors.wrap(output, like=inpt)
             return output
         else:
...
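The `tv_tensors.wrap(output, like=inpt)` pattern above re-attaches the input's subclass (and any metadata, e.g. a bounding box format) to a plain tensor result. A minimal hedged sketch of that helper in isolation (values invented, not part of this commit):

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 8, 8))
plain = img.as_subclass(torch.Tensor).flip(-1)   # some plain-tensor result computed from img
rewrapped = tv_tensors.wrap(plain, like=img)     # same data, typed as tv_tensors.Image again

print(type(rewrapped) is tv_tensors.Image)       # True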
@@ -5,7 +5,7 @@ import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
@@ -15,7 +15,7 @@ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
 from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
-ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
+ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]
 class _AutoAugmentBase(Transform):
@@ -46,7 +46,7 @@ class _AutoAugmentBase(Transform):
     def _flatten_and_extract_image_or_video(
         self,
         inputs: Any,
-        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
+        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
     ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
         needs_transform_list = self._needs_transform_list(flat_inputs)
@@ -56,10 +56,10 @@ class _AutoAugmentBase(Transform):
             if needs_transform and check_type(
                 inpt,
                 (
-                    datapoints.Image,
+                    tv_tensors.Image,
                     PIL.Image.Image,
                     is_pure_tensor,
-                    datapoints.Video,
+                    tv_tensors.Video,
                 ),
             ):
                 image_or_videos.append((idx, inpt))
@@ -590,7 +590,7 @@ class AugMix(_AutoAugmentBase):
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
         orig_dims = list(image_or_video.shape)
-        expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
+        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
         batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
@@ -627,8 +627,8 @@ class AugMix(_AutoAugmentBase):
             mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
         mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
-        if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
-            mix = datapoints.wrap(mix, like=orig_image_or_video)
+        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
+            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_pil_image(mix)
...
This diff is collapsed.
 from typing import Any, Dict, Union
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
@@ -10,20 +10,20 @@ class ConvertBoundingBoxFormat(Transform):
     .. v2betastatus:: ConvertBoundingBoxFormat transform
     Args:
-        format (str or datapoints.BoundingBoxFormat): output bounding box format.
-            Possible values are defined by :class:`~torchvision.datapoints.BoundingBoxFormat` and
+        format (str or tv_tensors.BoundingBoxFormat): output bounding box format.
+            Possible values are defined by :class:`~torchvision.tv_tensors.BoundingBoxFormat` and
             string values match the enums, e.g. "XYXY" or "XYWH" etc.
     """
-    _transformed_types = (datapoints.BoundingBoxes,)
+    _transformed_types = (tv_tensors.BoundingBoxes,)
-    def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
+    def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None:
         super().__init__()
         if isinstance(format, str):
-            format = datapoints.BoundingBoxFormat[format]
+            format = tv_tensors.BoundingBoxFormat[format]
         self.format = format
-    def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
+    def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
         return F.convert_bounding_box_format(inpt, new_format=self.format)  # type: ignore[return-value]
@@ -36,7 +36,7 @@ class ClampBoundingBoxes(Transform):
     """
-    _transformed_types = (datapoints.BoundingBoxes,)
+    _transformed_types = (tv_tensors.BoundingBoxes,)
-    def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
+    def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
         return F.clamp_bounding_boxes(inpt)  # type: ignore[return-value]
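A hedged usage sketch for the two transforms above under the new names (box values are invented and not taken from this commit):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

boxes = tv_tensors.BoundingBoxes(
    [[10, 10, 40, 20]], format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(32, 32)
)

cxcywh = v2.ConvertBoundingBoxFormat("CXCYWH")(boxes)   # a string or a BoundingBoxFormat both work
print(cxcywh.format)                                    # BoundingBoxFormat.CXCYWH

clamped = v2.ClampBoundingBoxes()(boxes)                # clips coordinates to the 32x32 canvas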
@@ -6,7 +6,7 @@ import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
@@ -74,7 +74,7 @@ class LinearTransformation(Transform):
     _v1_transform_cls = _transforms.LinearTransformation
-    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
     def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
         super().__init__()
@@ -129,8 +129,8 @@ class LinearTransformation(Transform):
         output = torch.mm(flat_inpt, transformation_matrix)
         output = output.reshape(shape)
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = datapoints.wrap(output, like=inpt)
+        if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+            output = tv_tensors.wrap(output, like=inpt)
         return output
@@ -227,12 +227,12 @@ class ToDtype(Transform):
     ``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``.
     Args:
-        dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to.
+        dtype (``torch.dtype`` or dict of ``TVTensor`` -> ``torch.dtype``): The dtype to convert to.
            If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted
            to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`.
-            A dict can be passed to specify per-datapoint conversions, e.g.
-            ``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others"
-            key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion.
+            A dict can be passed to specify per-tv_tensor conversions, e.g.
+            ``dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others":None}``. The "others"
+            key can be used as a catch-all for any other tv_tensor type, and ``None`` means no conversion.
         scale (bool, optional): Whether to scale the values for images or videos. See :ref:`range_and_dtype`.
             Default: ``False``.
     """
@@ -250,12 +250,12 @@ class ToDtype(Transform):
         if (
             isinstance(dtype, dict)
             and torch.Tensor in dtype
-            and any(cls in dtype for cls in [datapoints.Image, datapoints.Video])
+            and any(cls in dtype for cls in [tv_tensors.Image, tv_tensors.Video])
         ):
             warnings.warn(
-                "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Got `dtype` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
                 "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
-                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+                "in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
             )
         self.dtype = dtype
         self.scale = scale
@@ -264,7 +264,7 @@ class ToDtype(Transform):
         if isinstance(self.dtype, torch.dtype):
             # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
             # is a simple torch.dtype
-            if not is_pure_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
                 return inpt
             dtype: Optional[torch.dtype] = self.dtype
@@ -278,10 +278,10 @@ class ToDtype(Transform):
                 "If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. "
                 "If you're passing a dict as dtype, "
                 'you can use "others" as a catch-all key '
-                'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
+                'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
             )
-        supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
+        supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
         if dtype is None:
             if self.scale and supports_scaling:
                 warnings.warn(
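A quick illustration of the dict form documented in the ToDtype docstring above (a hedged sketch, not part of this commit; the sample dict and tensor shapes are made up):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import ToDtype

# Per-type dtype conversion: images become float32 (and are rescaled because scale=True),
# masks become int64, anything else is passed through unchanged.
transform = ToDtype(
    dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others": None},
    scale=True,
)
sample = {
    "image": tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)),
    "mask": tv_tensors.Mask(torch.zeros(32, 32, dtype=torch.uint8)),
}
out = transform(sample)
print(out["image"].dtype, out["mask"].dtype)  # torch.float32 torch.int64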
@@ -389,10 +389,10 @@ class SanitizeBoundingBoxes(Transform):
            )
         boxes = cast(
-            datapoints.BoundingBoxes,
+            tv_tensors.BoundingBoxes,
             F.convert_bounding_box_format(
                 boxes,
-                new_format=datapoints.BoundingBoxFormat.XYXY,
+                new_format=tv_tensors.BoundingBoxFormat.XYXY,
             ),
         )
         ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
@@ -415,7 +415,7 @@ class SanitizeBoundingBoxes(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         is_label = inpt is not None and inpt is params["labels"]
-        is_bounding_boxes_or_mask = isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask))
+        is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
         if not (is_label or is_bounding_boxes_or_mask):
             return inpt
@@ -425,4 +425,4 @@ class SanitizeBoundingBoxes(Transform):
         if is_label:
             return output
-        return datapoints.wrap(output, like=inpt)
+        return tv_tensors.wrap(output, like=inpt)
@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once
@@ -56,8 +56,8 @@ class Transform(nn.Module):
     def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
         # Below is a heuristic on how to deal with pure tensor inputs:
-        # 1. Pure tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
-        #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
+        # 1. Pure tensors, i.e. tensors that are not a tv_tensor, are passed through if there is an explicit image
+        #    (`tv_tensors.Image` or `PIL.Image.Image`) or video (`tv_tensors.Video`) in the sample.
         # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
         #    transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
         #    of `tree_flatten`, which recurses depth-first through the input.
@@ -72,7 +72,7 @@ class Transform(nn.Module):
         # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
         needs_transform_list = []
-        transform_pure_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
+        transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
         for inpt in flat_inputs:
             needs_transform = True
...
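To make the pure-tensor heuristic described in the comments above concrete, here is a small hedged sketch (not from this commit; the sample layout is invented): when an explicit `tv_tensors.Image` is present, plain tensors in the same sample are passed through untouched.

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

sample = {
    "img": tv_tensors.Image(torch.rand(3, 4, 4)),
    "aux": torch.rand(3, 4, 4),  # plain tensor, not a tv_tensor
}
flipped = v2.RandomHorizontalFlip(p=1.0)(sample)

# The explicit Image is flipped; the plain tensor is passed through unchanged.
assert torch.equal(flipped["aux"], sample["aux"])
assert not torch.equal(flipped["img"], sample["img"])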
@@ -4,7 +4,7 @@ import numpy as np
 import PIL.Image
 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import is_pure_tensor
@@ -27,7 +27,7 @@ class PILToTensor(Transform):
 class ToImage(Transform):
-    """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
+    """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.tv_tensors.Image`
     ; this does not scale values.
     .. v2betastatus:: ToImage transform
@@ -39,7 +39,7 @@ class ToImage(Transform):
     def _transform(
         self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
-    ) -> datapoints.Image:
+    ) -> tv_tensors.Image:
         return F.to_image(inpt)
@@ -66,7 +66,7 @@ class ToPILImage(Transform):
     .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
     """
-    _transformed_types = (is_pure_tensor, datapoints.Image, np.ndarray)
+    _transformed_types = (is_pure_tensor, tv_tensors.Image, np.ndarray)
     def __init__(self, mode: Optional[str] = None) -> None:
         super().__init__()
@@ -79,14 +79,14 @@ class ToPILImage(Transform):
 class ToPureTensor(Transform):
-    """[BETA] Convert all datapoints to pure tensors, removing associated metadata (if any).
+    """[BETA] Convert all tv_tensors to pure tensors, removing associated metadata (if any).
     .. v2betastatus:: ToPureTensor transform
     This doesn't scale or change the values, only the type.
     """
-    _transformed_types = (datapoints.Datapoint,)
+    _transformed_types = (tv_tensors.TVTensor,)
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         return inpt.as_subclass(torch.Tensor)
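A hedged round-trip sketch of the two conversion transforms above (the input array is made up):

import numpy as np
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

arr = np.zeros((8, 8, 3), dtype=np.uint8)          # HWC ndarray
img = v2.ToImage()(arr)                            # -> tv_tensors.Image in CHW layout, values unchanged
assert isinstance(img, tv_tensors.Image) and img.shape == (3, 8, 8)

plain = v2.ToPureTensor()(img)                     # strip the TVTensor subclass again
assert type(plain) is torch.Tensor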
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
 import PIL.Image
 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision._utils import sequence_to_str
@@ -149,10 +149,10 @@ def _parse_labels_getter(
         raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
-def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
     # This assumes there is only one bbox per sample as per the general convention
     try:
-        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes))
     except StopIteration:
         raise ValueError("No bounding boxes were found in the sample")
@@ -161,7 +161,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
+        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
@@ -179,11 +179,11 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
             inpt,
             (
                 is_pure_tensor,
-                datapoints.Image,
+                tv_tensors.Image,
                 PIL.Image.Image,
-                datapoints.Video,
-                datapoints.Mask,
-                datapoints.BoundingBoxes,
+                tv_tensors.Video,
+                tv_tensors.Mask,
+                tv_tensors.BoundingBoxes,
             ),
         )
     }
...
 import PIL.Image
 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
@@ -28,7 +28,7 @@ def erase(
 @_register_kernel_internal(erase, torch.Tensor)
-@_register_kernel_internal(erase, datapoints.Image)
+@_register_kernel_internal(erase, tv_tensors.Image)
 def erase_image(
     image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> torch.Tensor:
@@ -48,7 +48,7 @@ def _erase_image_pil(
     return to_pil_image(output, mode=image.mode)
-@_register_kernel_internal(erase, datapoints.Video)
+@_register_kernel_internal(erase, tv_tensors.Video)
 def erase_video(
     video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> torch.Tensor:
...
@@ -3,7 +3,7 @@ from typing import List
 import PIL.Image
 import torch
 from torch.nn.functional import conv2d
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms import _functional_pil as _FP
 from torchvision.transforms._functional_tensor import _max_value
@@ -47,7 +47,7 @@ def _rgb_to_grayscale_image(
 @_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
-@_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
+@_register_kernel_internal(rgb_to_grayscale, tv_tensors.Image)
 def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
     if num_output_channels not in (1, 3):
         raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
@@ -82,7 +82,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten
 @_register_kernel_internal(adjust_brightness, torch.Tensor)
-@_register_kernel_internal(adjust_brightness, datapoints.Image)
+@_register_kernel_internal(adjust_brightness, tv_tensors.Image)
 def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if brightness_factor < 0:
         raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
@@ -102,7 +102,7 @@ def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: floa
     return _FP.adjust_brightness(image, brightness_factor=brightness_factor)
-@_register_kernel_internal(adjust_brightness, datapoints.Video)
+@_register_kernel_internal(adjust_brightness, tv_tensors.Video)
 def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     return adjust_brightness_image(video, brightness_factor=brightness_factor)
@@ -119,7 +119,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten
 @_register_kernel_internal(adjust_saturation, torch.Tensor)
-@_register_kernel_internal(adjust_saturation, datapoints.Image)
+@_register_kernel_internal(adjust_saturation, tv_tensors.Image)
 def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if saturation_factor < 0:
         raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
@@ -141,7 +141,7 @@ def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> to
 _adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)
-@_register_kernel_internal(adjust_saturation, datapoints.Video)
+@_register_kernel_internal(adjust_saturation, tv_tensors.Video)
 def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     return adjust_saturation_image(video, saturation_factor=saturation_factor)
@@ -158,7 +158,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
 @_register_kernel_internal(adjust_contrast, torch.Tensor)
-@_register_kernel_internal(adjust_contrast, datapoints.Image)
+@_register_kernel_internal(adjust_contrast, tv_tensors.Image)
 def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if contrast_factor < 0:
         raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
@@ -180,7 +180,7 @@ def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.
 _adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)
-@_register_kernel_internal(adjust_contrast, datapoints.Video)
+@_register_kernel_internal(adjust_contrast, tv_tensors.Video)
 def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     return adjust_contrast_image(video, contrast_factor=contrast_factor)
@@ -197,7 +197,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso
 @_register_kernel_internal(adjust_sharpness, torch.Tensor)
-@_register_kernel_internal(adjust_sharpness, datapoints.Image)
+@_register_kernel_internal(adjust_sharpness, tv_tensors.Image)
 def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     num_channels, height, width = image.shape[-3:]
     if num_channels not in (1, 3):
@@ -253,7 +253,7 @@ def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torc
 _adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)
-@_register_kernel_internal(adjust_sharpness, datapoints.Video)
+@_register_kernel_internal(adjust_sharpness, tv_tensors.Video)
 def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     return adjust_sharpness_image(video, sharpness_factor=sharpness_factor)
@@ -340,7 +340,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(adjust_hue, torch.Tensor)
-@_register_kernel_internal(adjust_hue, datapoints.Image)
+@_register_kernel_internal(adjust_hue, tv_tensors.Image)
 def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if not (-0.5 <= hue_factor <= 0.5):
         raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
@@ -371,7 +371,7 @@ def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
 _adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)
-@_register_kernel_internal(adjust_hue, datapoints.Video)
+@_register_kernel_internal(adjust_hue, tv_tensors.Video)
 def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
     return adjust_hue_image(video, hue_factor=hue_factor)
@@ -388,7 +388,7 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten
 @_register_kernel_internal(adjust_gamma, torch.Tensor)
-@_register_kernel_internal(adjust_gamma, datapoints.Image)
+@_register_kernel_internal(adjust_gamma, tv_tensors.Image)
 def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
     if gamma < 0:
         raise ValueError("Gamma should be a non-negative real number")
@@ -411,7 +411,7 @@ def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) ->
 _adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)
-@_register_kernel_internal(adjust_gamma, datapoints.Video)
+@_register_kernel_internal(adjust_gamma, tv_tensors.Video)
 def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     return adjust_gamma_image(video, gamma=gamma, gain=gain)
@@ -428,7 +428,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
 @_register_kernel_internal(posterize, torch.Tensor)
-@_register_kernel_internal(posterize, datapoints.Image)
+@_register_kernel_internal(posterize, tv_tensors.Image)
 def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
     if image.is_floating_point():
         levels = 1 << bits
@@ -445,7 +445,7 @@ def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
 _posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)
-@_register_kernel_internal(posterize, datapoints.Video)
+@_register_kernel_internal(posterize, tv_tensors.Video)
 def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
     return posterize_image(video, bits=bits)
@@ -462,7 +462,7 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
 @_register_kernel_internal(solarize, torch.Tensor)
-@_register_kernel_internal(solarize, datapoints.Image)
+@_register_kernel_internal(solarize, tv_tensors.Image)
 def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
     if threshold > _max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
@@ -473,7 +473,7 @@ def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
 _solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)
-@_register_kernel_internal(solarize, datapoints.Video)
+@_register_kernel_internal(solarize, tv_tensors.Video)
 def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
     return solarize_image(video, threshold=threshold)
@@ -490,7 +490,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(autocontrast, torch.Tensor)
-@_register_kernel_internal(autocontrast, datapoints.Image)
+@_register_kernel_internal(autocontrast, tv_tensors.Image)
 def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
     c = image.shape[-3]
     if c not in [1, 3]:
@@ -523,7 +523,7 @@ def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
 _autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)
-@_register_kernel_internal(autocontrast, datapoints.Video)
+@_register_kernel_internal(autocontrast, tv_tensors.Video)
 def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
     return autocontrast_image(video)
@@ -540,7 +540,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(equalize, torch.Tensor)
-@_register_kernel_internal(equalize, datapoints.Image)
+@_register_kernel_internal(equalize, tv_tensors.Image)
 def equalize_image(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
@@ -613,7 +613,7 @@ def equalize_image(image: torch.Tensor) -> torch.Tensor:
 _equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)
-@_register_kernel_internal(equalize, datapoints.Video)
+@_register_kernel_internal(equalize, tv_tensors.Video)
 def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image(video)
@@ -630,7 +630,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(invert, torch.Tensor)
-@_register_kernel_internal(invert, datapoints.Image)
+@_register_kernel_internal(invert, tv_tensors.Image)
 def invert_image(image: torch.Tensor) -> torch.Tensor:
     if image.is_floating_point():
         return 1.0 - image
@@ -644,7 +644,7 @@ def invert_image(image: torch.Tensor) -> torch.Tensor:
 _invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)
-@_register_kernel_internal(invert, datapoints.Video)
+@_register_kernel_internal(invert, tv_tensors.Video)
 def invert_video(video: torch.Tensor) -> torch.Tensor:
     return invert_image(video)
@@ -653,7 +653,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
     """Permute the channels of the input according to the given permutation.
     This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
-    :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.
+    :class:`torchvision.tv_tensors.Image` and :class:`torchvision.tv_tensors.Video`.
     Example:
         >>> rgb_image = torch.rand(3, 256, 256)
@@ -681,7 +681,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
 @_register_kernel_internal(permute_channels, torch.Tensor)
-@_register_kernel_internal(permute_channels, datapoints.Image)
+@_register_kernel_internal(permute_channels, tv_tensors.Image)
 def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     shape = image.shape
     num_channels, height, width = shape[-3:]
@@ -704,6 +704,6 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int])
     return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))
-@_register_kernel_internal(permute_channels, datapoints.Video)
+@_register_kernel_internal(permute_channels, tv_tensors.Video)
 def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     return permute_channels_image(video, permutation=permutation)
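As a hedged sketch of how these registered kernels are reached from user code (not part of this commit; shapes and factors are illustrative), the public functionals dispatch on the input type and re-wrap tv_tensor inputs:

import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = tv_tensors.Image(torch.rand(3, 16, 16))
out = F.adjust_brightness(img, brightness_factor=1.5)
print(isinstance(out, tv_tensors.Image))   # True: the Image kernel ran and the result was re-wrapped

vid = tv_tensors.Video(torch.rand(2, 3, 16, 16))
print(isinstance(F.adjust_brightness(vid, brightness_factor=1.5), tv_tensors.Video))  # True: Video kernel

pil = PIL.Image.new("RGB", (16, 16))
print(isinstance(F.adjust_brightness(pil, brightness_factor=1.5), PIL.Image.Image))   # True: PIL kernel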
@@ -2,9 +2,9 @@ from typing import List, Optional, Tuple
 import PIL.Image
 import torch
-from torchvision import datapoints
-from torchvision.datapoints import BoundingBoxFormat
+from torchvision import tv_tensors
 from torchvision.transforms import _functional_pil as _FP
+from torchvision.tv_tensors import BoundingBoxFormat
 from torchvision.utils import _log_api_usage_once
@@ -22,7 +22,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
 @_register_kernel_internal(get_dimensions, torch.Tensor)
-@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
+@_register_kernel_internal(get_dimensions, tv_tensors.Image, tv_tensor_wrapper=False)
 def get_dimensions_image(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])
     ndims = len(chw)
@@ -38,7 +38,7 @@ def get_dimensions_image(image: torch.Tensor) -> List[int]:
 _get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
-@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
+@_register_kernel_internal(get_dimensions, tv_tensors.Video, tv_tensor_wrapper=False)
 def get_dimensions_video(video: torch.Tensor) -> List[int]:
     return get_dimensions_image(video)
@@ -54,7 +54,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
 @_register_kernel_internal(get_num_channels, torch.Tensor)
-@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
+@_register_kernel_internal(get_num_channels, tv_tensors.Image, tv_tensor_wrapper=False)
 def get_num_channels_image(image: torch.Tensor) -> int:
     chw = image.shape[-3:]
     ndims = len(chw)
@@ -69,7 +69,7 @@ def get_num_channels_image(image: torch.Tensor) -> int:
 _get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
-@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
+@_register_kernel_internal(get_num_channels, tv_tensors.Video, tv_tensor_wrapper=False)
 def get_num_channels_video(video: torch.Tensor) -> int:
     return get_num_channels_image(video)
@@ -90,7 +90,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
 @_register_kernel_internal(get_size, torch.Tensor)
-@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
+@_register_kernel_internal(get_size, tv_tensors.Image, tv_tensor_wrapper=False)
 def get_size_image(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
     ndims = len(hw)
@@ -106,18 +106,18 @@ def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     return [height, width]
-@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
+@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
 def get_size_video(video: torch.Tensor) -> List[int]:
     return get_size_image(video)
-@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(get_size, tv_tensors.Mask, tv_tensor_wrapper=False)
 def get_size_mask(mask: torch.Tensor) -> List[int]:
     return get_size_image(mask)
-@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
-def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
+@_register_kernel_internal(get_size, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
+def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> List[int]:
     return list(bounding_box.canvas_size)
@@ -132,7 +132,7 @@ def get_num_frames(inpt: torch.Tensor) -> int:
 @_register_kernel_internal(get_num_frames, torch.Tensor)
-@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
+@_register_kernel_internal(get_num_frames, tv_tensors.Video, tv_tensor_wrapper=False)
 def get_num_frames_video(video: torch.Tensor) -> int:
     return video.shape[-4]
@@ -205,7 +205,7 @@ def convert_bounding_box_format(
 ) -> torch.Tensor:
     """[BETA] See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details."""
     # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
-    # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
+    # inputs as well as extract it from `tv_tensors.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
     # default error that would be thrown if `new_format` had no default value.
     if new_format is None:
@@ -218,16 +218,16 @@ def convert_bounding_box_format(
         if old_format is None:
             raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
         return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
-    elif isinstance(inpt, datapoints.BoundingBoxes):
+    elif isinstance(inpt, tv_tensors.BoundingBoxes):
         if old_format is not None:
-            raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
+            raise ValueError("For bounding box tv_tensor inputs, `old_format` must not be passed.")
         output = _convert_bounding_box_format(
             inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
         )
-        return datapoints.wrap(output, like=inpt, format=new_format)
+        return tv_tensors.wrap(output, like=inpt, format=new_format)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
+            f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
        )
@@ -239,7 +239,7 @@ def _clamp_bounding_boxes(
     in_dtype = bounding_boxes.dtype
     bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
     xyxy_boxes = convert_bounding_box_format(
-        bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+        bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
     )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
@@ -263,12 +263,12 @@ def clamp_bounding_boxes(
         if format is None or canvas_size is None:
             raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
         return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
-    elif isinstance(inpt, datapoints.BoundingBoxes):
+    elif isinstance(inpt, tv_tensors.BoundingBoxes):
         if format is not None or canvas_size is not None:
-            raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
+            raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
         output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
-        return datapoints.wrap(output, like=inpt)
+        return tv_tensors.wrap(output, like=inpt)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
+            f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
        )
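A hedged usage sketch of the two kernel/functional hybrids above (values are made up): pass `old_format` only for plain tensors, never for `tv_tensors.BoundingBoxes`, which already carry their format and canvas size.

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# Plain tensor: the current format must be given explicitly.
plain = torch.tensor([[10.0, 10.0, 20.0, 20.0]])
xywh = F.convert_bounding_box_format(
    plain, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=tv_tensors.BoundingBoxFormat.XYWH
)

# BoundingBoxes tv_tensor: format/canvas_size come from the input and the result is re-wrapped.
boxes = tv_tensors.BoundingBoxes([[-5, 2, 60, 30]], format="XYXY", canvas_size=(32, 32))
clamped = F.clamp_bounding_boxes(boxes)   # coordinates clipped to the 32x32 canvas
print(type(clamped) is type(boxes))       # True: still a BoundingBoxes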
 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.utils import _log_api_usage_once
@@ -19,7 +19,7 @@ def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Te
 @_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
-@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
+@_register_kernel_internal(uniform_temporal_subsample, tv_tensors.Video)
 def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
     # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
     t_max = video.shape[-4] - 1
...
@@ -3,12 +3,12 @@ from typing import Union
 import numpy as np
 import PIL.Image
 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms import functional as _F
 @torch.jit.unused
-def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
+def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image:
     """[BETA] See :class:`~torchvision.transforms.v2.ToImage` for details."""
     if isinstance(inpt, np.ndarray):
         output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
@@ -18,7 +18,7 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoin
         output = inpt
     else:
         raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
-    return datapoints.Image(output)
+    return tv_tensors.Image(output)
 to_pil_image = _F.to_pil_image
...