"vscode:/vscode.git/clone" did not exist on "d693034ecfb6ce62fbfe168004682dccee471f8c"
Unverified Commit d5f4cc38 authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
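Before the hunks, a minimal sketch of what this rename means for user code. It is illustrative only and assumes just the names that appear in the diff below (`tv_tensors`, `tv_tensors.Image`, `tv_tensors.wrap`); the tensor shapes and values are made up.

```python
import torch

# Old spelling, removed by this commit:
#   from torchvision import datapoints
#   img = datapoints.Image(torch.rand(3, 224, 224))
#   out = datapoints.wrap(img * 0.5, like=img)

# New spelling introduced here:
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 224, 224))   # tensor subclass carrying image semantics
half = img * 0.5                                   # plain tensor ops generally return a plain torch.Tensor
out = tv_tensors.wrap(half, like=img)              # re-wrap the result as the same tv_tensor type
```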
......@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch
from torch.utils._pytree import tree_map
from torchvision.datapoints._datapoint import Datapoint
from torchvision.tv_tensors._tv_tensor import TVTensor
L = TypeVar("L", bound="_LabelBase")
class _LabelBase(Datapoint):
class _LabelBase(TVTensor):
categories: Optional[Sequence[str]]
@classmethod
......
......@@ -7,7 +7,7 @@ import PIL.Image
import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F
from ._transform import _RandomApplyTransform, Transform
......@@ -91,10 +91,10 @@ class RandomErasing(_RandomApplyTransform):
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(functional, inpt, *args, **kwargs)
......@@ -158,7 +158,7 @@ class _BaseMixUpCutMix(Transform):
flat_inputs, spec = tree_flatten(inputs)
needs_transform_list = self._needs_transform_list(flat_inputs)
if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask):
if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")
labels = self._labels_getter(inputs)
......@@ -188,7 +188,7 @@ class _BaseMixUpCutMix(Transform):
return tree_unflatten(flat_outputs, spec)
def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
expected_num_dims = 5 if isinstance(inpt, datapoints.Video) else 4
expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
if inpt.ndim != expected_num_dims:
raise ValueError(
f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
......@@ -242,13 +242,13 @@ class MixUp(_BaseMixUpCutMix):
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=lam)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = datapoints.wrap(output, like=inpt)
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
else:
......@@ -309,7 +309,7 @@ class CutMix(_BaseMixUpCutMix):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=params["lam_adjusted"])
elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
x1, y1, x2, y2 = params["box"]
......@@ -317,8 +317,8 @@ class CutMix(_BaseMixUpCutMix):
output = inpt.clone()
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = datapoints.wrap(output, like=inpt)
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
else:
......
......@@ -5,7 +5,7 @@ import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import datapoints, transforms as _transforms
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
......@@ -15,7 +15,7 @@ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]
class _AutoAugmentBase(Transform):
......@@ -46,7 +46,7 @@ class _AutoAugmentBase(Transform):
def _flatten_and_extract_image_or_video(
self,
inputs: Any,
unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
needs_transform_list = self._needs_transform_list(flat_inputs)
......@@ -56,10 +56,10 @@ class _AutoAugmentBase(Transform):
if needs_transform and check_type(
inpt,
(
datapoints.Image,
tv_tensors.Image,
PIL.Image.Image,
is_pure_tensor,
datapoints.Video,
tv_tensors.Video,
),
):
image_or_videos.append((idx, inpt))
......@@ -590,7 +590,7 @@ class AugMix(_AutoAugmentBase):
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image_or_video.shape)
expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
......@@ -627,8 +627,8 @@ class AugMix(_AutoAugmentBase):
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
mix = datapoints.wrap(mix, like=orig_image_or_video)
if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
mix = tv_tensors.wrap(mix, like=orig_image_or_video)
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_pil_image(mix)
......
This diff is collapsed.
from typing import Any, Dict, Union
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
......@@ -10,20 +10,20 @@ class ConvertBoundingBoxFormat(Transform):
.. v2betastatus:: ConvertBoundingBoxFormat transform
Args:
format (str or datapoints.BoundingBoxFormat): output bounding box format.
Possible values are defined by :class:`~torchvision.datapoints.BoundingBoxFormat` and
format (str or tv_tensors.BoundingBoxFormat): output bounding box format.
Possible values are defined by :class:`~torchvision.tv_tensors.BoundingBoxFormat` and
string values match the enums, e.g. "XYXY" or "XYWH" etc.
"""
_transformed_types = (datapoints.BoundingBoxes,)
_transformed_types = (tv_tensors.BoundingBoxes,)
def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None:
super().__init__()
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]
format = tv_tensors.BoundingBoxFormat[format]
self.format = format
def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value]
......@@ -36,7 +36,7 @@ class ClampBoundingBoxes(Transform):
"""
_transformed_types = (datapoints.BoundingBoxes,)
_transformed_types = (tv_tensors.BoundingBoxes,)
def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
return F.clamp_bounding_boxes(inpt) # type: ignore[return-value]
......@@ -6,7 +6,7 @@ import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
......@@ -74,7 +74,7 @@ class LinearTransformation(Transform):
_v1_transform_cls = _transforms.LinearTransformation
_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
......@@ -129,8 +129,8 @@ class LinearTransformation(Transform):
output = torch.mm(flat_inpt, transformation_matrix)
output = output.reshape(shape)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = datapoints.wrap(output, like=inpt)
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
......@@ -227,12 +227,12 @@ class ToDtype(Transform):
``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``.
Args:
dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to.
dtype (``torch.dtype`` or dict of ``TVTensor`` -> ``torch.dtype``): The dtype to convert to.
If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted
to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`.
A dict can be passed to specify per-datapoint conversions, e.g.
``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others"
key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion.
A dict can be passed to specify per-tv_tensor conversions, e.g.
``dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others":None}``. The "others"
key can be used as a catch-all for any other tv_tensor type, and ``None`` means no conversion.
scale (bool, optional): Whether to scale the values for images or videos. See :ref:`range_and_dtype`.
Default: ``False``.
"""
......@@ -250,12 +250,12 @@ class ToDtype(Transform):
if (
isinstance(dtype, dict)
and torch.Tensor in dtype
and any(cls in dtype for cls in [datapoints.Image, datapoints.Video])
and any(cls in dtype for cls in [tv_tensors.Image, tv_tensors.Video])
):
warnings.warn(
"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Got `dtype` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
)
self.dtype = dtype
self.scale = scale
......@@ -264,7 +264,7 @@ class ToDtype(Transform):
if isinstance(self.dtype, torch.dtype):
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# is a simple torch.dtype
if not is_pure_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
return inpt
dtype: Optional[torch.dtype] = self.dtype
......@@ -278,10 +278,10 @@ class ToDtype(Transform):
"If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. "
"If you're passing a dict as dtype, "
'you can use "others" as a catch-all key '
'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
)
supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
if dtype is None:
if self.scale and supports_scaling:
warnings.warn(
......@@ -389,10 +389,10 @@ class SanitizeBoundingBoxes(Transform):
)
boxes = cast(
datapoints.BoundingBoxes,
tv_tensors.BoundingBoxes,
F.convert_bounding_box_format(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
),
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
......@@ -415,7 +415,7 @@ class SanitizeBoundingBoxes(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
is_label = inpt is not None and inpt is params["labels"]
is_bounding_boxes_or_mask = isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask))
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
if not (is_label or is_bounding_boxes_or_mask):
return inpt
......@@ -425,4 +425,4 @@ class SanitizeBoundingBoxes(Transform):
if is_label:
return output
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
......@@ -7,7 +7,7 @@ import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.utils import _log_api_usage_once
......@@ -56,8 +56,8 @@ class Transform(nn.Module):
def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
# Below is a heuristic on how to deal with pure tensor inputs:
# 1. Pure tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
# 1. Pure tensors, i.e. tensors that are not a tv_tensor, are passed through if there is an explicit image
# (`tv_tensors.Image` or `PIL.Image.Image`) or video (`tv_tensors.Video`) in the sample.
# 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
# transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
# of `tree_flatten`, which recurses depth-first through the input.
......@@ -72,7 +72,7 @@ class Transform(nn.Module):
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
needs_transform_list = []
transform_pure_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
for inpt in flat_inputs:
needs_transform = True
......
......@@ -4,7 +4,7 @@ import numpy as np
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import is_pure_tensor
......@@ -27,7 +27,7 @@ class PILToTensor(Transform):
class ToImage(Transform):
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.tv_tensors.Image`
; this does not scale values.
.. v2betastatus:: ToImage transform
......@@ -39,7 +39,7 @@ class ToImage(Transform):
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image:
) -> tv_tensors.Image:
return F.to_image(inpt)
......@@ -66,7 +66,7 @@ class ToPILImage(Transform):
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
"""
_transformed_types = (is_pure_tensor, datapoints.Image, np.ndarray)
_transformed_types = (is_pure_tensor, tv_tensors.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
......@@ -79,14 +79,14 @@ class ToPILImage(Transform):
class ToPureTensor(Transform):
"""[BETA] Convert all datapoints to pure tensors, removing associated metadata (if any).
"""[BETA] Convert all tv_tensors to pure tensors, removing associated metadata (if any).
.. v2betastatus:: ToPureTensor transform
This doesn't scale or change the values, only the type.
"""
_transformed_types = (datapoints.Datapoint,)
_transformed_types = (tv_tensors.TVTensor,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
return inpt.as_subclass(torch.Tensor)
......@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision._utils import sequence_to_str
......@@ -149,10 +149,10 @@ def _parse_labels_getter(
raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
# This assumes there is only one bbox per sample as per the general convention
try:
return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes))
except StopIteration:
raise ValueError("No bounding boxes were found in the sample")
......@@ -161,7 +161,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
}
if not chws:
raise TypeError("No image or video was found in the sample")
......@@ -179,11 +179,11 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
inpt,
(
is_pure_tensor,
datapoints.Image,
tv_tensors.Image,
PIL.Image.Image,
datapoints.Video,
datapoints.Mask,
datapoints.BoundingBoxes,
tv_tensors.Video,
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
),
)
}
......
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
......@@ -28,7 +28,7 @@ def erase(
@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, datapoints.Image)
@_register_kernel_internal(erase, tv_tensors.Image)
def erase_image(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
......@@ -48,7 +48,7 @@ def _erase_image_pil(
return to_pil_image(output, mode=image.mode)
@_register_kernel_internal(erase, datapoints.Video)
@_register_kernel_internal(erase, tv_tensors.Video)
def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
......
......@@ -3,7 +3,7 @@ from typing import List
import PIL.Image
import torch
from torch.nn.functional import conv2d
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _max_value
......@@ -47,7 +47,7 @@ def _rgb_to_grayscale_image(
@_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
@_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
@_register_kernel_internal(rgb_to_grayscale, tv_tensors.Image)
def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
......@@ -82,7 +82,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten
@_register_kernel_internal(adjust_brightness, torch.Tensor)
@_register_kernel_internal(adjust_brightness, datapoints.Image)
@_register_kernel_internal(adjust_brightness, tv_tensors.Image)
def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
......@@ -102,7 +102,7 @@ def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: floa
return _FP.adjust_brightness(image, brightness_factor=brightness_factor)
@_register_kernel_internal(adjust_brightness, datapoints.Video)
@_register_kernel_internal(adjust_brightness, tv_tensors.Video)
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
return adjust_brightness_image(video, brightness_factor=brightness_factor)
......@@ -119,7 +119,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten
@_register_kernel_internal(adjust_saturation, torch.Tensor)
@_register_kernel_internal(adjust_saturation, datapoints.Image)
@_register_kernel_internal(adjust_saturation, tv_tensors.Image)
def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
......@@ -141,7 +141,7 @@ def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> to
_adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)
@_register_kernel_internal(adjust_saturation, datapoints.Video)
@_register_kernel_internal(adjust_saturation, tv_tensors.Video)
def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
return adjust_saturation_image(video, saturation_factor=saturation_factor)
......@@ -158,7 +158,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
@_register_kernel_internal(adjust_contrast, torch.Tensor)
@_register_kernel_internal(adjust_contrast, datapoints.Image)
@_register_kernel_internal(adjust_contrast, tv_tensors.Image)
def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
......@@ -180,7 +180,7 @@ def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.
_adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)
@_register_kernel_internal(adjust_contrast, datapoints.Video)
@_register_kernel_internal(adjust_contrast, tv_tensors.Video)
def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
return adjust_contrast_image(video, contrast_factor=contrast_factor)
......@@ -197,7 +197,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso
@_register_kernel_internal(adjust_sharpness, torch.Tensor)
@_register_kernel_internal(adjust_sharpness, datapoints.Image)
@_register_kernel_internal(adjust_sharpness, tv_tensors.Image)
def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = image.shape[-3:]
if num_channels not in (1, 3):
......@@ -253,7 +253,7 @@ def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torc
_adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)
@_register_kernel_internal(adjust_sharpness, datapoints.Video)
@_register_kernel_internal(adjust_sharpness, tv_tensors.Video)
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
return adjust_sharpness_image(video, sharpness_factor=sharpness_factor)
......@@ -340,7 +340,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(adjust_hue, torch.Tensor)
@_register_kernel_internal(adjust_hue, datapoints.Image)
@_register_kernel_internal(adjust_hue, tv_tensors.Image)
def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
......@@ -371,7 +371,7 @@ def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
_adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)
@_register_kernel_internal(adjust_hue, datapoints.Video)
@_register_kernel_internal(adjust_hue, tv_tensors.Video)
def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image(video, hue_factor=hue_factor)
......@@ -388,7 +388,7 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten
@_register_kernel_internal(adjust_gamma, torch.Tensor)
@_register_kernel_internal(adjust_gamma, datapoints.Image)
@_register_kernel_internal(adjust_gamma, tv_tensors.Image)
def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
......@@ -411,7 +411,7 @@ def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) ->
_adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)
@_register_kernel_internal(adjust_gamma, datapoints.Video)
@_register_kernel_internal(adjust_gamma, tv_tensors.Video)
def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
return adjust_gamma_image(video, gamma=gamma, gain=gain)
......@@ -428,7 +428,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
@_register_kernel_internal(posterize, torch.Tensor)
@_register_kernel_internal(posterize, datapoints.Image)
@_register_kernel_internal(posterize, tv_tensors.Image)
def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
if image.is_floating_point():
levels = 1 << bits
......@@ -445,7 +445,7 @@ def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
_posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)
@_register_kernel_internal(posterize, datapoints.Video)
@_register_kernel_internal(posterize, tv_tensors.Video)
def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image(video, bits=bits)
......@@ -462,7 +462,7 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
@_register_kernel_internal(solarize, torch.Tensor)
@_register_kernel_internal(solarize, datapoints.Image)
@_register_kernel_internal(solarize, tv_tensors.Image)
def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
if threshold > _max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
......@@ -473,7 +473,7 @@ def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
_solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)
@_register_kernel_internal(solarize, datapoints.Video)
@_register_kernel_internal(solarize, tv_tensors.Video)
def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image(video, threshold=threshold)
......@@ -490,7 +490,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(autocontrast, torch.Tensor)
@_register_kernel_internal(autocontrast, datapoints.Image)
@_register_kernel_internal(autocontrast, tv_tensors.Image)
def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
c = image.shape[-3]
if c not in [1, 3]:
......@@ -523,7 +523,7 @@ def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
_autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)
@_register_kernel_internal(autocontrast, datapoints.Video)
@_register_kernel_internal(autocontrast, tv_tensors.Video)
def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image(video)
......@@ -540,7 +540,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(equalize, torch.Tensor)
@_register_kernel_internal(equalize, datapoints.Image)
@_register_kernel_internal(equalize, tv_tensors.Image)
def equalize_image(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
return image
......@@ -613,7 +613,7 @@ def equalize_image(image: torch.Tensor) -> torch.Tensor:
_equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)
@_register_kernel_internal(equalize, datapoints.Video)
@_register_kernel_internal(equalize, tv_tensors.Video)
def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image(video)
......@@ -630,7 +630,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(invert, torch.Tensor)
@_register_kernel_internal(invert, datapoints.Image)
@_register_kernel_internal(invert, tv_tensors.Image)
def invert_image(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point():
return 1.0 - image
......@@ -644,7 +644,7 @@ def invert_image(image: torch.Tensor) -> torch.Tensor:
_invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)
@_register_kernel_internal(invert, datapoints.Video)
@_register_kernel_internal(invert, tv_tensors.Video)
def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image(video)
......@@ -653,7 +653,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
"""Permute the channels of the input according to the given permutation.
This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
:class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.
:class:`torchvision.tv_tensors.Image` and :class:`torchvision.tv_tensors.Video`.
Example:
>>> rgb_image = torch.rand(3, 256, 256)
......@@ -681,7 +681,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
@_register_kernel_internal(permute_channels, torch.Tensor)
@_register_kernel_internal(permute_channels, datapoints.Image)
@_register_kernel_internal(permute_channels, tv_tensors.Image)
def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
shape = image.shape
num_channels, height, width = shape[-3:]
......@@ -704,6 +704,6 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int])
return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))
@_register_kernel_internal(permute_channels, datapoints.Video)
@_register_kernel_internal(permute_channels, tv_tensors.Video)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
return permute_channels_image(video, permutation=permutation)
......@@ -2,9 +2,9 @@ from typing import List, Optional, Tuple
import PIL.Image
import torch
from torchvision import datapoints
from torchvision.datapoints import BoundingBoxFormat
from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.tv_tensors import BoundingBoxFormat
from torchvision.utils import _log_api_usage_once
......@@ -22,7 +22,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_dimensions, torch.Tensor)
@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
@_register_kernel_internal(get_dimensions, tv_tensors.Image, tv_tensor_wrapper=False)
def get_dimensions_image(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:])
ndims = len(chw)
......@@ -38,7 +38,7 @@ def get_dimensions_image(image: torch.Tensor) -> List[int]:
_get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
@_register_kernel_internal(get_dimensions, tv_tensors.Video, tv_tensor_wrapper=False)
def get_dimensions_video(video: torch.Tensor) -> List[int]:
return get_dimensions_image(video)
......@@ -54,7 +54,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
@_register_kernel_internal(get_num_channels, torch.Tensor)
@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
@_register_kernel_internal(get_num_channels, tv_tensors.Image, tv_tensor_wrapper=False)
def get_num_channels_image(image: torch.Tensor) -> int:
chw = image.shape[-3:]
ndims = len(chw)
......@@ -69,7 +69,7 @@ def get_num_channels_image(image: torch.Tensor) -> int:
_get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
@_register_kernel_internal(get_num_channels, tv_tensors.Video, tv_tensor_wrapper=False)
def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image(video)
......@@ -90,7 +90,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_size, torch.Tensor)
@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
@_register_kernel_internal(get_size, tv_tensors.Image, tv_tensor_wrapper=False)
def get_size_image(image: torch.Tensor) -> List[int]:
hw = list(image.shape[-2:])
ndims = len(hw)
......@@ -106,18 +106,18 @@ def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
return [height, width]
@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
def get_size_video(video: torch.Tensor) -> List[int]:
return get_size_image(video)
@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
@_register_kernel_internal(get_size, tv_tensors.Mask, tv_tensor_wrapper=False)
def get_size_mask(mask: torch.Tensor) -> List[int]:
return get_size_image(mask)
@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
@_register_kernel_internal(get_size, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> List[int]:
return list(bounding_box.canvas_size)
......@@ -132,7 +132,7 @@ def get_num_frames(inpt: torch.Tensor) -> int:
@_register_kernel_internal(get_num_frames, torch.Tensor)
@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
@_register_kernel_internal(get_num_frames, tv_tensors.Video, tv_tensor_wrapper=False)
def get_num_frames_video(video: torch.Tensor) -> int:
return video.shape[-4]
......@@ -205,7 +205,7 @@ def convert_bounding_box_format(
) -> torch.Tensor:
"""[BETA] See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details."""
# This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# inputs as well as extract it from `tv_tensors.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
if new_format is None:
......@@ -218,16 +218,16 @@ def convert_bounding_box_format(
if old_format is None:
raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
elif isinstance(inpt, datapoints.BoundingBoxes):
elif isinstance(inpt, tv_tensors.BoundingBoxes):
if old_format is not None:
raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
raise ValueError("For bounding box tv_tensor inputs, `old_format` must not be passed.")
output = _convert_bounding_box_format(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
)
return datapoints.wrap(output, like=inpt, format=new_format)
return tv_tensors.wrap(output, like=inpt, format=new_format)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
)
......@@ -239,7 +239,7 @@ def _clamp_bounding_boxes(
in_dtype = bounding_boxes.dtype
bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
xyxy_boxes = convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
)
xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
......@@ -263,12 +263,12 @@ def clamp_bounding_boxes(
if format is None or canvas_size is None:
raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
elif isinstance(inpt, datapoints.BoundingBoxes):
elif isinstance(inpt, tv_tensors.BoundingBoxes):
if format is not None or canvas_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
return datapoints.wrap(output, like=inpt)
return tv_tensors.wrap(output, like=inpt)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
)
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.utils import _log_api_usage_once
......@@ -19,7 +19,7 @@ def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Te
@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
@_register_kernel_internal(uniform_temporal_subsample, tv_tensors.Video)
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
t_max = video.shape[-4] - 1
......
......@@ -3,12 +3,12 @@ from typing import Union
import numpy as np
import PIL.Image
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import functional as _F
@torch.jit.unused
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image:
"""[BETA] See :class:`~torchvision.transforms.v2.ToImage` for details."""
if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
......@@ -18,7 +18,7 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoin
output = inpt
else:
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
return datapoints.Image(output)
return tv_tensors.Image(output)
to_pil_image = _F.to_pil_image
......
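To close out, a hedged sketch of how the renamed `wrap` helper interacts with `BoundingBoxes` metadata, mirroring the pattern in the `convert_bounding_box_format` and `clamp_bounding_boxes` hunks above; the box coordinates and canvas size are invented for the example.

```python
import torch
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 20.0, 50.0, 80.0]]),      # a single box, values are illustrative
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(100, 100),                        # (height, width) of the underlying image
)

# The kernels above operate on the raw tensor ...
shifted = boxes.as_subclass(torch.Tensor) + 5

# ... then re-attach format/canvas_size metadata via wrap(), as the diff does.
shifted = tv_tensors.wrap(shifted, like=boxes)
print(shifted.format, shifted.canvas_size)         # BoundingBoxFormat.XYXY (100, 100)
```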