Unverified Commit a8007dcd authored by Philip Meier, committed by GitHub

rename features._Feature to datapoints._Datapoint (#7002)

* rename features._Feature to datapoints.Datapoint

* _Datapoint to Datapoint

* move is_simple_tensor to transforms.utils

* fix CI

* move Datapoint out of public namespace
parent c093b9c0
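In downstream terms the change is mechanical: the tensor subclasses that the prototype transforms dispatch on now live in torchvision.prototype.datapoints instead of torchvision.prototype.features, their common base class is Datapoint (kept in the private _datapoint module), and the is_simple_tensor helper moved to torchvision.prototype.transforms.utils. A minimal before/after sketch of the new spelling, assuming the prototype import paths exactly as they appear in this diff (the prototype API was unstable at the time, so treat the paths as illustrative rather than a stable public contract):

    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms.utils import is_simple_tensor

    # Old spelling removed by this commit:
    #   from torchvision.prototype import features
    #   features.Image, features.Label, features.is_simple_tensor, ...

    image = datapoints.Image(torch.rand(3, 224, 224))  # tensor subclass carrying image metadata
    plain = torch.rand(3, 224, 224)                     # a "simple" tensor with no Datapoint wrapper

    print(is_simple_tensor(image))  # expected False: Image derives from the (now private) Datapoint base
    print(is_simple_tensor(plain))  # expected True: plain tensors are dispatched separately by the transforms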
@@ -5,9 +5,9 @@ import pathlib
 from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union
 from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedData, EncodedImage
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
-from torchvision.prototype.features import Label

 __all__ = ["from_data_folder", "from_image_folder"]
@@ -7,13 +7,13 @@ from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
 import PIL.Image
 import torch
-from torchvision.prototype.features._feature import _Feature
+from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer

 D = TypeVar("D", bound="EncodedData")


-class EncodedData(_Feature):
+class EncodedData(Datapoint):
     @classmethod
     def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
         return tensor.as_subclass(cls)
@@ -6,16 +6,17 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union
 import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops import masks_to_boxes
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, InterpolationMode

 from ._transform import _RandomApplyTransform
-from .utils import has_any, query_chw, query_spatial_size
+from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size


 class RandomErasing(_RandomApplyTransform):
-    _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video)
+    _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

     def __init__(
         self,
@@ -91,8 +92,8 @@ class RandomErasing(_RandomApplyTransform):
         return dict(i=i, j=j, h=h, w=w, v=v)

     def _transform(
-        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
-    ) -> Union[features.ImageType, features.VideoType]:
+        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
+    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         if params["v"] is not None:
             inpt = F.erase(inpt, **params, inplace=self.inplace)
@@ -107,20 +108,20 @@ class _BaseMixupCutmix(_RandomApplyTransform):
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
-            has_any(flat_inputs, features.Image, features.Video, features.is_simple_tensor)
-            and has_any(flat_inputs, features.OneHotLabel)
+            has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
+            and has_any(flat_inputs, datapoints.OneHotLabel)
         ):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(flat_inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
+        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
             )

-    def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
+    def _mixup_onehotlabel(self, inpt: datapoints.OneHotLabel, lam: float) -> datapoints.OneHotLabel:
         if inpt.ndim < 2:
             raise ValueError("Need a batch of one hot labels")
         output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-        return features.OneHotLabel.wrap_like(inpt, output)
+        return datapoints.OneHotLabel.wrap_like(inpt, output)


 class RandomMixup(_BaseMixupCutmix):
@@ -129,17 +130,17 @@ class RandomMixup(_BaseMixupCutmix):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         lam = params["lam"]
-        if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
-            expected_ndim = 5 if isinstance(inpt, features.Video) else 4
+        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
             if inpt.ndim < expected_ndim:
                 raise ValueError("The transform expects a batched input")
             output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-            if isinstance(inpt, (features.Image, features.Video)):
+            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                 output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
             return output
-        elif isinstance(inpt, features.OneHotLabel):
+        elif isinstance(inpt, datapoints.OneHotLabel):
             return self._mixup_onehotlabel(inpt, lam)
         else:
             return inpt
@@ -169,9 +170,9 @@ class RandomCutmix(_BaseMixupCutmix):
         return dict(box=box, lam_adjusted=lam_adjusted)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
+        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
             box = params["box"]
-            expected_ndim = 5 if isinstance(inpt, features.Video) else 4
+            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
             if inpt.ndim < expected_ndim:
                 raise ValueError("The transform expects a batched input")
             x1, y1, x2, y2 = box
@@ -179,11 +180,11 @@ class RandomCutmix(_BaseMixupCutmix):
             output = inpt.clone()
             output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
-            if isinstance(inpt, (features.Image, features.Video)):
+            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                 output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
             return output
-        elif isinstance(inpt, features.OneHotLabel):
+        elif isinstance(inpt, datapoints.OneHotLabel):
             lam_adjusted = params["lam_adjusted"]
             return self._mixup_onehotlabel(inpt, lam_adjusted)
         else:
@@ -205,15 +206,15 @@ class SimpleCopyPaste(_RandomApplyTransform):
     def _copy_paste(
         self,
-        image: features.TensorImageType,
+        image: datapoints.TensorImageType,
         target: Dict[str, Any],
-        paste_image: features.TensorImageType,
+        paste_image: datapoints.TensorImageType,
         paste_target: Dict[str, Any],
         random_selection: torch.Tensor,
         blending: bool,
         resize_interpolation: F.InterpolationMode,
         antialias: Optional[bool],
-    ) -> Tuple[features.TensorImageType, Dict[str, Any]]:
+    ) -> Tuple[datapoints.TensorImageType, Dict[str, Any]]:
         paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
         paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
@@ -262,7 +263,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
         # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
         xyxy_boxes[:, 2:] += 1
         boxes = F.convert_format_bounding_box(
-            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
+            xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
         )
         out_target["boxes"] = torch.cat([boxes, paste_boxes])
@@ -271,7 +272,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
         # Check for degenerated boxes and remove them
         boxes = F.convert_format_bounding_box(
-            out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
+            out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
         )
         degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
         if degenerate_boxes.any():
@@ -285,20 +286,20 @@ class SimpleCopyPaste(_RandomApplyTransform):
     def _extract_image_targets(
         self, flat_sample: List[Any]
-    ) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]:
+    ) -> Tuple[List[datapoints.TensorImageType], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
         # with List[image], List[BoundingBox], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
         for obj in flat_sample:
-            if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
+            if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
                 images.append(obj)
             elif isinstance(obj, PIL.Image.Image):
                 images.append(F.to_image_tensor(obj))
-            elif isinstance(obj, features.BoundingBox):
+            elif isinstance(obj, datapoints.BoundingBox):
                 bboxes.append(obj)
-            elif isinstance(obj, features.Mask):
+            elif isinstance(obj, datapoints.Mask):
                 masks.append(obj)
-            elif isinstance(obj, (features.Label, features.OneHotLabel)):
+            elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
                 labels.append(obj)

         if not (len(images) == len(bboxes) == len(masks) == len(labels)):
@@ -316,27 +317,27 @@ class SimpleCopyPaste(_RandomApplyTransform):
     def _insert_outputs(
         self,
         flat_sample: List[Any],
-        output_images: List[features.TensorImageType],
+        output_images: List[datapoints.TensorImageType],
         output_targets: List[Dict[str, Any]],
     ) -> None:
         c0, c1, c2, c3 = 0, 0, 0, 0
         for i, obj in enumerate(flat_sample):
-            if isinstance(obj, features.Image):
-                flat_sample[i] = features.Image.wrap_like(obj, output_images[c0])
+            if isinstance(obj, datapoints.Image):
+                flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0])
                 c0 += 1
             elif isinstance(obj, PIL.Image.Image):
                 flat_sample[i] = F.to_image_pil(output_images[c0])
                 c0 += 1
-            elif features.is_simple_tensor(obj):
+            elif is_simple_tensor(obj):
                 flat_sample[i] = output_images[c0]
                 c0 += 1
-            elif isinstance(obj, features.BoundingBox):
-                flat_sample[i] = features.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
+            elif isinstance(obj, datapoints.BoundingBox):
+                flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
                 c1 += 1
-            elif isinstance(obj, features.Mask):
-                flat_sample[i] = features.Mask.wrap_like(obj, output_targets[c2]["masks"])
+            elif isinstance(obj, datapoints.Mask):
+                flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])
                 c2 += 1
-            elif isinstance(obj, (features.Label, features.OneHotLabel)):
+            elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
                 flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
                 c3 += 1
@@ -5,13 +5,14 @@ import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.prototype.transforms.functional._meta import get_spatial_size
 from torchvision.transforms import functional_tensor as _FT

 from ._utils import _setup_fill_arg
-from .utils import check_type
+from .utils import check_type, is_simple_tensor


 class _AutoAugmentBase(Transform):
@@ -19,7 +20,7 @@ class _AutoAugmentBase(Transform):
         self,
         *,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = interpolation
@@ -33,13 +34,21 @@ class _AutoAugmentBase(Transform):
     def _flatten_and_extract_image_or_video(
         self,
         inputs: Any,
-        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
-    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[features.ImageType, features.VideoType]]:
+        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

         image_or_videos = []
         for idx, inpt in enumerate(flat_inputs):
-            if check_type(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
+            if check_type(
+                inpt,
+                (
+                    datapoints.Image,
+                    PIL.Image.Image,
+                    is_simple_tensor,
+                    datapoints.Video,
+                ),
+            ):
                 image_or_videos.append((idx, inpt))
             elif isinstance(inpt, unsupported_types):
                 raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
@@ -58,7 +67,7 @@ class _AutoAugmentBase(Transform):
     def _unflatten_and_insert_image_or_video(
         self,
         flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
-        image_or_video: Union[features.ImageType, features.VideoType],
+        image_or_video: Union[datapoints.ImageType, datapoints.VideoType],
     ) -> Any:
         flat_inputs, spec, idx = flat_inputs_with_spec
         flat_inputs[idx] = image_or_video
@@ -66,12 +75,12 @@ class _AutoAugmentBase(Transform):
     def _apply_image_or_video_transform(
         self,
-        image: Union[features.ImageType, features.VideoType],
+        image: Union[datapoints.ImageType, datapoints.VideoType],
         transform_id: str,
         magnitude: float,
         interpolation: InterpolationMode,
-        fill: Dict[Type, features.FillTypeJIT],
-    ) -> Union[features.ImageType, features.VideoType]:
+        fill: Dict[Type, datapoints.FillTypeJIT],
+    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         fill_ = fill[type(image)]

         if transform_id == "Identity":
@@ -182,7 +191,7 @@ class AutoAugment(_AutoAugmentBase):
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
@@ -338,7 +347,7 @@ class RandAugment(_AutoAugmentBase):
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
@@ -390,7 +399,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         self,
         num_magnitude_bins: int = 31,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
@@ -446,7 +455,7 @@ class AugMix(_AutoAugmentBase):
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
@@ -474,7 +483,7 @@ class AugMix(_AutoAugmentBase):
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

         orig_dims = list(image_or_video.shape)
-        expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
+        expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
         batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
@@ -511,7 +520,7 @@ class AugMix(_AutoAugmentBase):
             mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
         mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

-        if isinstance(orig_image_or_video, (features.Image, features.Video)):
+        if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
             mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_image_pil(mix)
@@ -3,11 +3,12 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import PIL.Image
 import torch
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform

 from ._transform import _RandomApplyTransform
-from .utils import query_chw
+from .utils import is_simple_tensor, query_chw


 class ColorJitter(Transform):
@@ -82,7 +83,12 @@ class ColorJitter(Transform):

 class RandomPhotometricDistort(Transform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )

     def __init__(
         self,
@@ -111,15 +117,15 @@ class RandomPhotometricDistort(Transform):
         )

     def _permute_channels(
-        self, inpt: Union[features.ImageType, features.VideoType], permutation: torch.Tensor
-    ) -> Union[features.ImageType, features.VideoType]:
+        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
+    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         if isinstance(inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)

         output = inpt[..., permutation, :, :]

-        if isinstance(inpt, (features.Image, features.Video)):
-            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
+        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER)  # type: ignore[arg-type]
         elif isinstance(inpt, PIL.Image.Image):
             output = F.to_image_pil(output)
@@ -127,8 +133,8 @@ class RandomPhotometricDistort(Transform):
         return output

     def _transform(
-        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
-    ) -> Union[features.ImageType, features.VideoType]:
+        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
+    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         if params["brightness"]:
             inpt = F.adjust_brightness(
                 inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
@@ -5,13 +5,13 @@ import numpy as np
 import PIL.Image
 import torch
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
 from typing_extensions import Literal

 from ._transform import _RandomApplyTransform
-from .utils import query_chw
+from .utils import is_simple_tensor, query_chw


 class ToTensor(Transform):
@@ -29,7 +29,12 @@ class ToTensor(Transform):

 class Grayscale(Transform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )

     def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         deprecation_msg = (
@@ -53,16 +58,21 @@ class Grayscale(Transform):
         self.num_output_channels = num_output_channels

     def _transform(
-        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
-    ) -> Union[features.ImageType, features.VideoType]:
+        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
+    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
-        if isinstance(inpt, (features.Image, features.Video)):
-            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
+        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY)  # type: ignore[arg-type]
         return output


 class RandomGrayscale(_RandomApplyTransform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )

     def __init__(self, p: float = 0.1) -> None:
         warnings.warn(
@@ -84,9 +94,9 @@ class RandomGrayscale(_RandomApplyTransform):
         return dict(num_input_channels=num_input_channels)

     def _transform(
-        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
-    ) -> Union[features.ImageType, features.VideoType]:
+        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
+    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
-        if isinstance(inpt, (features.Image, features.Video)):
-            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
+        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY)  # type: ignore[arg-type]
         return output
@@ -5,8 +5,9 @@ from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
 import PIL.Image
 import torch
 from torchvision.ops.boxes import box_iou
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
 from torchvision.transforms.functional import _get_perspective_coeffs
@@ -22,7 +23,7 @@ from ._utils import (
     _setup_float_or_seq,
     _setup_size,
 )
-from .utils import has_all, has_any, query_bounding_box, query_spatial_size
+from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size


 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -145,23 +146,23 @@ class RandomResizedCrop(Transform):
         )

-ImageOrVideoTypeJIT = Union[features.ImageTypeJIT, features.VideoTypeJIT]
+ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]


 class FiveCrop(Transform):
     """
     Example:
         >>> class BatchMultiCrop(transforms.Transform):
-        ...     def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]):
+        ...     def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], datapoints.Label]):
         ...         images_or_videos, labels = sample
         ...         batch_size = len(images_or_videos)
         ...         image_or_video = images_or_videos[0]
         ...         images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
-        ...         labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
+        ...         labels = datapoints.Label.wrap_like(labels, labels.repeat(batch_size))
         ...         return images_or_videos, labels
         ...
-        >>> image = features.Image(torch.rand(3, 256, 256))
-        >>> label = features.Label(0)
+        >>> image = datapoints.Image(torch.rand(3, 256, 256))
+        >>> label = datapoints.Label(0)
         >>> transform = transforms.Compose([transforms.FiveCrop(), BatchMultiCrop()])
         >>> images, labels = transform(image, label)
         >>> images.shape
@@ -170,7 +171,12 @@ class FiveCrop(Transform):
         torch.Size([5])
     """

-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )

     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
@@ -182,7 +188,7 @@ class FiveCrop(Transform):
         return F.five_crop(inpt, self.size)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if has_any(flat_inputs, features.BoundingBox, features.Mask):
+        if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
             raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
@@ -191,7 +197,12 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """

-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )

     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
@@ -199,12 +210,12 @@ class TenCrop(Transform):
         self.vertical_flip = vertical_flip

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if has_any(flat_inputs, features.BoundingBox, features.Mask):
+        if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
             raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(
-        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
-    ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
+        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
+    ) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
@@ -212,7 +223,7 @@ class Pad(Transform):
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -235,7 +246,7 @@ class Pad(Transform):
 class RandomZoomOut(_RandomApplyTransform):
     def __init__(
         self,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
@@ -276,7 +287,7 @@ class RandomRotation(Transform):
         degrees: Union[numbers.Number, Sequence],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -315,7 +326,7 @@ class RandomAffine(Transform):
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -390,7 +401,7 @@ class RandomCrop(Transform):
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -480,7 +491,7 @@ class RandomPerspective(_RandomApplyTransform):
     def __init__(
         self,
         distortion_scale: float = 0.5,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         p: float = 0.5,
     ) -> None:
@@ -540,7 +551,7 @@ class ElasticTransform(Transform):
         self,
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
@@ -606,9 +617,9 @@ class RandomIoUCrop(Transform):
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
-            has_all(flat_inputs, features.BoundingBox)
-            and has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
-            and has_any(flat_inputs, features.Label, features.OneHotLabel)
+            has_all(flat_inputs, datapoints.BoundingBox)
+            and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
+            and has_any(flat_inputs, datapoints.Label, datapoints.OneHotLabel)
         ):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
@@ -646,7 +657,7 @@ class RandomIoUCrop(Transform):
                 # check for any valid boxes with centers within the crop area
                 xyxy_bboxes = F.convert_format_bounding_box(
-                    bboxes.as_subclass(torch.Tensor), bboxes.format, features.BoundingBoxFormat.XYXY
+                    bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
                 )
                 cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                 cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
@@ -671,19 +682,19 @@ class RandomIoUCrop(Transform):
         is_within_crop_area = params["is_within_crop_area"]

-        if isinstance(inpt, (features.Label, features.OneHotLabel)):
+        if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel)):
             return inpt.wrap_like(inpt, inpt[is_within_crop_area])  # type: ignore[arg-type]

         output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

-        if isinstance(output, features.BoundingBox):
+        if isinstance(output, datapoints.BoundingBox):
             bboxes = output[is_within_crop_area]
             bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size)
-            output = features.BoundingBox.wrap_like(output, bboxes)
-        elif isinstance(output, features.Mask):
+            output = datapoints.BoundingBox.wrap_like(output, bboxes)
+        elif isinstance(output, datapoints.Mask):
             # apply is_within_crop_area if mask is one-hot encoded
             masks = output[is_within_crop_area]
-            output = features.Mask.wrap_like(output, masks)
+            output = datapoints.Mask.wrap_like(output, masks)

         return output
@@ -751,7 +762,7 @@ class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
+        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
@@ -764,13 +775,19 @@ class FixedSizeCrop(Transform):
         self.padding_mode = padding_mode

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if not has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
+        if not has_any(
+            flat_inputs,
+            PIL.Image.Image,
+            datapoints.Image,
+            is_simple_tensor,
+            datapoints.Video,
+        ):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
             )

-        if has_any(flat_inputs, features.BoundingBox) and not has_any(
-            flat_inputs, features.Label, features.OneHotLabel
+        if has_any(flat_inputs, datapoints.BoundingBox) and not has_any(
+            flat_inputs, datapoints.Label, datapoints.OneHotLabel
         ):
             raise TypeError(
                 f"If a BoundingBox is contained in the input sample, "
@@ -809,7 +826,7 @@ class FixedSizeCrop(Transform):
             )
             bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
             height_and_width = F.convert_format_bounding_box(
-                bounding_boxes, old_format=format, new_format=features.BoundingBoxFormat.XYWH
+                bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH
             )[..., 2:]
             is_valid = torch.all(height_and_width > 0, dim=-1)
         else:
@@ -842,10 +859,10 @@ class FixedSizeCrop(Transform):
         )

         if params["is_valid"] is not None:
-            if isinstance(inpt, (features.Label, features.OneHotLabel, features.Mask)):
+            if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel, datapoints.Mask)):
                 inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
-            elif isinstance(inpt, features.BoundingBox):
-                inpt = features.BoundingBox.wrap_like(
+            elif isinstance(inpt, datapoints.BoundingBox):
+                inpt = datapoints.BoundingBox.wrap_like(
                     inpt,
                     F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size),
                 )
@@ -3,38 +3,41 @@ from typing import Any, Dict, Optional, Union
 import PIL.Image
 import torch
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform

+from .utils import is_simple_tensor


 class ConvertBoundingBoxFormat(Transform):
-    _transformed_types = (features.BoundingBox,)
+    _transformed_types = (datapoints.BoundingBox,)

-    def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
+    def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
         super().__init__()
         if isinstance(format, str):
-            format = features.BoundingBoxFormat[format]
+            format = datapoints.BoundingBoxFormat[format]
         self.format = format

-    def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
+    def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
         # We need to unwrap here to avoid unnecessary `__torch_function__` calls,
         # since `convert_format_bounding_box` does not have a dispatcher function that would do that for us
         output = F.convert_format_bounding_box(
             inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"]
         )
-        return features.BoundingBox.wrap_like(inpt, output, format=params["format"])
+        return datapoints.BoundingBox.wrap_like(inpt, output, format=params["format"])


 class ConvertDtype(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
+    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype

     def _transform(
-        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
-    ) -> Union[features.TensorImageType, features.TensorVideoType]:
+        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
+    ) -> Union[datapoints.TensorImageType, datapoints.TensorVideoType]:
         return F.convert_dtype(inpt, self.dtype)
@@ -44,36 +47,41 @@ ConvertImageDtype = ConvertDtype

 class ConvertColorSpace(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video)
+    _transformed_types = (
+        is_simple_tensor,
+        datapoints.Image,
+        PIL.Image.Image,
+        datapoints.Video,
+    )

     def __init__(
         self,
-        color_space: Union[str, features.ColorSpace],
-        old_color_space: Optional[Union[str, features.ColorSpace]] = None,
+        color_space: Union[str, datapoints.ColorSpace],
+        old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None,
     ) -> None:
         super().__init__()

         if isinstance(color_space, str):
-            color_space = features.ColorSpace.from_str(color_space)
+            color_space = datapoints.ColorSpace.from_str(color_space)
         self.color_space = color_space

         if isinstance(old_color_space, str):
-            old_color_space = features.ColorSpace.from_str(old_color_space)
+            old_color_space = datapoints.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space

     def _transform(
-        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
-    ) -> Union[features.ImageType, features.VideoType]:
+        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
+    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
         return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)


 class ClampBoundingBoxes(Transform):
-    _transformed_types = (features.BoundingBox,)
+    _transformed_types = (datapoints.BoundingBox,)

-    def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
+    def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
         # We need to unwrap here to avoid unnecessary `__torch_function__` calls,
         # since `clamp_bounding_box` does not have a dispatcher function that would do that for us
         output = F.clamp_bounding_box(
             inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
         )
-        return features.BoundingBox.wrap_like(inpt, output)
+        return datapoints.BoundingBox.wrap_like(inpt, output)
@@ -3,12 +3,13 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, U
 import PIL.Image
 import torch
 from torchvision.ops import remove_small_boxes
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform

 from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
-from .utils import has_any, query_bounding_box
+from .utils import has_any, is_simple_tensor, query_bounding_box


 class Identity(Transform):
@@ -38,7 +39,7 @@ class Lambda(Transform):

 class LinearTransformation(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
+    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
         super().__init__()
@@ -67,7 +68,7 @@ class LinearTransformation(Transform):
            raise TypeError("LinearTransformation does not work on PIL Images")

     def _transform(
-        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
         # Image instance after linear transformation is not Image anymore due to unknown data range
         # Thus we will return Tensor for input Image
@@ -93,7 +94,7 @@ class LinearTransformation(Transform):

 class Normalize(Transform):
-    _transformed_types = (features.Image, features.is_simple_tensor, features.Video)
+    _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)

     def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
         super().__init__()
@@ -106,7 +107,7 @@ class Normalize(Transform):
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")

     def _transform(
-        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
@@ -158,7 +159,7 @@ class ToDtype(Transform):

 class PermuteDimensions(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
+    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
         super().__init__()
@@ -167,7 +168,7 @@ class PermuteDimensions(Transform):
         self.dims = dims

     def _transform(
-        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
@@ -176,7 +177,7 @@ class PermuteDimensions(Transform):

 class TransposeDimensions(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
+    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
         super().__init__()
@@ -185,7 +186,7 @@ class TransposeDimensions(Transform):
         self.dims = dims

     def _transform(
-        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
@@ -194,7 +195,7 @@ class TransposeDimensions(Transform):

 class RemoveSmallBoundingBoxes(Transform):
-    _transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
+    _transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)

     def __init__(self, min_size: float = 1.0) -> None:
         super().__init__()
@@ -210,7 +211,7 @@ class RemoveSmallBoundingBoxes(Transform):
         bounding_box = F.convert_format_bounding_box(
             bounding_box.as_subclass(torch.Tensor),
             old_format=bounding_box.format,
-            new_format=features.BoundingBoxFormat.XYXY,
+            new_format=datapoints.BoundingBoxFormat.XYXY,
         )
         valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
......
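As context for the `_transformed_types` tuples above: with this PR, `is_simple_tensor` moves from the old `features` namespace to `torchvision.prototype.transforms.utils`. A minimal sketch of the idea it encodes (not the exact implementation) is:

```python
# Sketch only: roughly the check the transforms above rely on.
# After this PR the real helper lives in torchvision.prototype.transforms.utils.
import torch
from torchvision.prototype.datapoints._datapoint import Datapoint

def is_simple_tensor(inpt) -> bool:
    # A "simple" tensor is a plain torch.Tensor that is *not* one of the
    # Datapoint subclasses (Image, Video, BoundingBox, Mask, Label, ...).
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)
```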
from typing import Any, Dict from typing import Any, Dict
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms.utils import is_simple_tensor
class UniformTemporalSubsample(Transform): class UniformTemporalSubsample(Transform):
_transformed_types = (features.is_simple_tensor, features.Video) _transformed_types = (is_simple_tensor, datapoints.Video)
def __init__(self, num_samples: int, temporal_dim: int = -4): def __init__(self, num_samples: int, temporal_dim: int = -4):
super().__init__() super().__init__()
self.num_samples = num_samples self.num_samples = num_samples
self.temporal_dim = temporal_dim self.temporal_dim = temporal_dim
def _transform(self, inpt: features.VideoType, params: Dict[str, Any]) -> features.VideoType: def _transform(self, inpt: datapoints.VideoType, params: Dict[str, Any]) -> datapoints.VideoType:
return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim) return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim)
...@@ -5,23 +5,26 @@ import PIL.Image ...@@ -5,23 +5,26 @@ import PIL.Image
import torch import torch
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms.utils import is_simple_tensor
class LabelToOneHot(Transform): class LabelToOneHot(Transform):
_transformed_types = (features.Label,) _transformed_types = (datapoints.Label,)
def __init__(self, num_categories: int = -1): def __init__(self, num_categories: int = -1):
super().__init__() super().__init__()
self.num_categories = num_categories self.num_categories = num_categories
def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.OneHotLabel: def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel:
num_categories = self.num_categories num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None: if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories) num_categories = len(inpt.categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return features.OneHotLabel(output, categories=inpt.categories) return datapoints.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str: def extra_repr(self) -> str:
if self.num_categories == -1: if self.num_categories == -1:
...@@ -38,16 +41,16 @@ class PILToTensor(Transform): ...@@ -38,16 +41,16 @@ class PILToTensor(Transform):
class ToImageTensor(Transform): class ToImageTensor(Transform):
_transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray) _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform( def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> features.Image: ) -> datapoints.Image:
return F.to_image_tensor(inpt) # type: ignore[no-any-return] return F.to_image_tensor(inpt) # type: ignore[no-any-return]
class ToImagePIL(Transform): class ToImagePIL(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, np.ndarray) _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None: def __init__(self, mode: Optional[str] = None) -> None:
super().__init__() super().__init__()
......
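Illustrative usage of the conversion transforms above under the renamed namespace. This is a sketch against the prototype API as it appears in this diff; the `categories` keyword on `Label` and the export of `LabelToOneHot` from `torchvision.prototype.transforms` are assumptions, not a stable interface:

```python
import torch
from torchvision.prototype import datapoints, transforms

# Label -> OneHotLabel, inferring the number of classes from the categories.
label = datapoints.Label(2, categories=["cat", "dog", "bird"])
one_hot = transforms.LabelToOneHot()(label)      # datapoints.OneHotLabel of shape (3,)

# Plain tensor / PIL image / numpy array -> datapoints.Image.
img = transforms.ToImageTensor()(torch.rand(3, 32, 32))
```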
...@@ -3,8 +3,8 @@ import numbers ...@@ -3,8 +3,8 @@ import numbers
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Sequence, Type, TypeVar, Union from typing import Any, Dict, Sequence, Type, TypeVar, Union
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.prototype.features._feature import FillType, FillTypeJIT from torchvision.prototype.datapoints._datapoint import FillType, FillTypeJIT
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
...@@ -54,7 +54,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]: ...@@ -54,7 +54,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
return defaultdict(functools.partial(_default_arg, default)) return defaultdict(functools.partial(_default_arg, default))
def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT: def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0 # So, we can't reassign fill to 0
# if fill is None: # if fill is None:
......
...@@ -3,7 +3,7 @@ from typing import Union ...@@ -3,7 +3,7 @@ from typing import Union
import PIL.Image import PIL.Image
import torch import torch
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
...@@ -33,28 +33,28 @@ def erase_video( ...@@ -33,28 +33,28 @@ def erase_video(
def erase( def erase(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
i: int, i: int,
j: int, j: int,
h: int, h: int,
w: int, w: int,
v: torch.Tensor, v: torch.Tensor,
inplace: bool = False, inplace: bool = False,
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]: ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, features.Image): elif isinstance(inpt, datapoints.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Image.wrap_like(inpt, output) return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, features.Video): elif isinstance(inpt, datapoints.Video):
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Video.wrap_like(inpt, output) return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, " f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
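A usage sketch of the dispatch above: a `datapoints.Image` keeps its type via `wrap_like`, while a plain tensor goes straight through the tensor kernel. Shapes here are illustrative assumptions:

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img = datapoints.Image(torch.rand(3, 32, 32))
patch = torch.zeros(3, 8, 8)

out = F.erase(img, i=4, j=4, h=8, w=8, v=patch)                              # still a datapoints.Image
plain = F.erase(img.as_subclass(torch.Tensor), i=4, j=4, h=8, w=8, v=patch)  # plain torch.Tensor path
```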
import PIL.Image import PIL.Image
import torch import torch
from torch.nn.functional import conv2d from torch.nn.functional import conv2d
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.transforms import functional_pil as _FP from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value from torchvision.transforms.functional_tensor import _max_value
...@@ -37,16 +37,18 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to ...@@ -37,16 +37,18 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT: def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor) return inpt.adjust_brightness(brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
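The same three-way dispatch repeats for every color op in this file: plain tensors (including scripted code) hit the `*_image_tensor` kernel, `Datapoint` subclasses delegate to their method, and PIL images hit the PIL kernel. A usage sketch, with the `Video` construction and its `(T, C, H, W)` layout assumed from the wrapping pattern used elsewhere in this diff:

```python
import PIL.Image
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

video = datapoints.Video(torch.rand(4, 3, 16, 16))   # assumed (T, C, H, W) layout
F.adjust_brightness(video, brightness_factor=1.5)    # Datapoint path -> stays a Video

F.adjust_brightness(PIL.Image.new("RGB", (16, 16)), brightness_factor=1.5)  # PIL path
```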
...@@ -76,16 +78,18 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to ...@@ -76,16 +78,18 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)
def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT: def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_saturation(saturation_factor=saturation_factor) return inpt.adjust_saturation(saturation_factor=saturation_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -115,16 +119,18 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. ...@@ -115,16 +119,18 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)
def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT: def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_contrast(contrast_factor=contrast_factor) return inpt.adjust_contrast(contrast_factor=contrast_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -188,16 +194,18 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc ...@@ -188,16 +194,18 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)
def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT: def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor) return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -300,16 +308,18 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: ...@@ -300,16 +308,18 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image_tensor(video, hue_factor=hue_factor) return adjust_hue_image_tensor(video, hue_factor=hue_factor)
def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT: def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_hue(hue_factor=hue_factor) return inpt.adjust_hue(hue_factor=hue_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor) return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -340,16 +350,18 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to ...@@ -340,16 +350,18 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)
def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT: def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_gamma(gamma=gamma, gain=gain) return inpt.adjust_gamma(gamma=gamma, gain=gain)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -374,16 +386,18 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: ...@@ -374,16 +386,18 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image_tensor(video, bits=bits) return posterize_image_tensor(video, bits=bits)
def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT: def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return posterize_image_tensor(inpt, bits=bits) return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.posterize(bits=bits) return inpt.posterize(bits=bits)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits) return posterize_image_pil(inpt, bits=bits)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -402,16 +416,18 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: ...@@ -402,16 +416,18 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image_tensor(video, threshold=threshold) return solarize_image_tensor(video, threshold=threshold)
def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT: def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return solarize_image_tensor(inpt, threshold=threshold) return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.solarize(threshold=threshold) return inpt.solarize(threshold=threshold)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold) return solarize_image_pil(inpt, threshold=threshold)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -452,16 +468,18 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor: ...@@ -452,16 +468,18 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image_tensor(video) return autocontrast_image_tensor(video)
def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return autocontrast_image_tensor(inpt) return autocontrast_image_tensor(inpt)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.autocontrast() return inpt.autocontrast()
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt) return autocontrast_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -542,16 +560,18 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor: ...@@ -542,16 +560,18 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image_tensor(video) return equalize_image_tensor(video)
def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return equalize_image_tensor(inpt) return equalize_image_tensor(inpt)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.equalize() return inpt.equalize()
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt) return equalize_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -573,15 +593,17 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: ...@@ -573,15 +593,17 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video) return invert_image_tensor(video)
def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT: def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return invert_image_tensor(inpt) return invert_image_tensor(inpt)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.invert() return inpt.invert()
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt) return invert_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -4,16 +4,16 @@ from typing import Any, List, Union ...@@ -4,16 +4,16 @@ from typing import Any, List, Union
import PIL.Image import PIL.Image
import torch import torch
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.transforms import functional as _F from torchvision.transforms import functional as _F
@torch.jit.unused @torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
call = ", num_output_channels=3" if num_output_channels == 3 else "" call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = "convert_color_space(..., color_space=features.ColorSpace.GRAY)" replacement = "convert_color_space(..., color_space=datapoints.ColorSpace.GRAY)"
if num_output_channels == 3: if num_output_channels == 3:
replacement = f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB)" replacement = f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB)"
warnings.warn( warnings.warn(
f"The function `to_grayscale(...{call})` is deprecated in will be removed in a future release. " f"The function `to_grayscale(...{call})` is deprecated in will be removed in a future release. "
f"Instead, please use `{replacement}`.", f"Instead, please use `{replacement}`.",
...@@ -23,25 +23,25 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima ...@@ -23,25 +23,25 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale( def rgb_to_grayscale(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], num_output_channels: int = 1 inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]: ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)): if not torch.jit.is_scripting() and isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor) inpt = inpt.as_subclass(torch.Tensor)
old_color_space = None old_color_space = None
elif isinstance(inpt, torch.Tensor): elif isinstance(inpt, torch.Tensor):
old_color_space = features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
else: else:
old_color_space = None old_color_space = None
call = ", num_output_channels=3" if num_output_channels == 3 else "" call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = ( replacement = (
f"convert_color_space(..., color_space=features.ColorSpace.GRAY" f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
f"{f', old_color_space=features.ColorSpace.{old_color_space}' if old_color_space is not None else ''})" f"{f', old_color_space=datapoints.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
) )
if num_output_channels == 3: if num_output_channels == 3:
replacement = ( replacement = (
f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB" f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB"
f"{f', old_color_space=features.ColorSpace.GRAY' if old_color_space is not None else ''})" f"{f', old_color_space=datapoints.ColorSpace.GRAY' if old_color_space is not None else ''})"
) )
warnings.warn( warnings.warn(
f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. " f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. "
...@@ -60,7 +60,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: ...@@ -60,7 +60,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
return _F.to_tensor(inpt) return _F.to_tensor(inpt)
def get_image_size(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]: def get_image_size(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
warnings.warn( warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. " "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
......
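The deprecation above boils down to the order of the returned values; a small sketch of the assumed behavior:

```python
import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 32, 48)      # (C, H, W)
F.get_spatial_size(img)          # [32, 48] -> [h, w]
F.get_image_size(img)            # [48, 32] -> [w, h], deprecated and emits the warning above
```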
...@@ -2,8 +2,8 @@ from typing import List, Optional, Tuple, Union ...@@ -2,8 +2,8 @@ from typing import List, Optional, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value from torchvision.transforms.functional_tensor import _max_value
...@@ -23,12 +23,12 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: ...@@ -23,12 +23,12 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
get_dimensions_image_pil = _FP.get_dimensions get_dimensions_image_pil = _FP.get_dimensions
def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]: def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
return get_dimensions_image_tensor(inpt) return get_dimensions_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video)): elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
channels = inpt.num_channels channels = inpt.num_channels
height, width = inpt.spatial_size height, width = inpt.spatial_size
return [channels, height, width] return [channels, height, width]
...@@ -36,7 +36,7 @@ def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> ...@@ -36,7 +36,7 @@ def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) ->
return get_dimensions_image_pil(inpt) return get_dimensions_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, " f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -59,18 +59,18 @@ def get_num_channels_video(video: torch.Tensor) -> int: ...@@ -59,18 +59,18 @@ def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image_tensor(video) return get_num_channels_image_tensor(video)
def get_num_channels(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int: def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> int:
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
return get_num_channels_image_tensor(inpt) return get_num_channels_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video)): elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.num_channels return inpt.num_channels
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return get_num_channels_image_pil(inpt) return get_num_channels_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, " f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -104,20 +104,22 @@ def get_spatial_size_mask(mask: torch.Tensor) -> List[int]: ...@@ -104,20 +104,22 @@ def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
@torch.jit.unused @torch.jit.unused
def get_spatial_size_bounding_box(bounding_box: features.BoundingBox) -> List[int]: def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]:
return list(bounding_box.spatial_size) return list(bounding_box.spatial_size)
def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]: def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return get_spatial_size_image_tensor(inpt) return get_spatial_size_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)): elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
return list(inpt.spatial_size) return list(inpt.spatial_size)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return] return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return]
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -126,15 +128,13 @@ def get_num_frames_video(video: torch.Tensor) -> int: ...@@ -126,15 +128,13 @@ def get_num_frames_video(video: torch.Tensor) -> int:
return video.shape[-4] return video.shape[-4]
def get_num_frames(inpt: features.VideoTypeJIT) -> int: def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)): if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
return get_num_frames_video(inpt) return get_num_frames_video(inpt)
elif isinstance(inpt, features.Video): elif isinstance(inpt, datapoints.Video):
return inpt.num_frames return inpt.num_frames
else: else:
raise TypeError( raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
)
def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor: def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
...@@ -202,7 +202,7 @@ def clamp_bounding_box( ...@@ -202,7 +202,7 @@ def clamp_bounding_box(
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth # BoundingBoxFormat instead of converting back and forth
xyxy_boxes = convert_format_bounding_box( xyxy_boxes = convert_format_bounding_box(
bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True bounding_box.clone(), old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
) )
xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1]) xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0]) xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
...@@ -309,12 +309,12 @@ def convert_color_space_video( ...@@ -309,12 +309,12 @@ def convert_color_space_video(
def convert_color_space( def convert_color_space(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
color_space: ColorSpace, color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None, old_color_space: Optional[ColorSpace] = None,
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]: ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
if old_color_space is None: if old_color_space is None:
raise RuntimeError( raise RuntimeError(
...@@ -322,21 +322,21 @@ def convert_color_space( ...@@ -322,21 +322,21 @@ def convert_color_space(
"the `old_color_space=...` parameter needs to be passed." "the `old_color_space=...` parameter needs to be passed."
) )
return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space) return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
elif isinstance(inpt, features.Image): elif isinstance(inpt, datapoints.Image):
output = convert_color_space_image_tensor( output = convert_color_space_image_tensor(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
) )
return features.Image.wrap_like(inpt, output, color_space=color_space) return datapoints.Image.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, features.Video): elif isinstance(inpt, datapoints.Video):
output = convert_color_space_video( output = convert_color_space_video(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
) )
return features.Video.wrap_like(inpt, output, color_space=color_space) return datapoints.Video.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return convert_color_space_image_pil(inpt, color_space=color_space) return convert_color_space_image_pil(inpt, color_space=color_space)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, " f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
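Usage sketch for the dispatch above: for `Image`/`Video` datapoints the source color space is read from the input, while a plain tensor needs `old_color_space` passed explicitly (otherwise the `RuntimeError` above is raised). The `color_space` constructor argument on `Image` is assumed from the `inpt.color_space` access above:

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img = datapoints.Image(torch.rand(3, 16, 16), color_space=datapoints.ColorSpace.RGB)
gray = F.convert_color_space(img, color_space=datapoints.ColorSpace.GRAY)

plain = torch.rand(3, 16, 16)
gray_t = F.convert_color_space(
    plain,
    color_space=datapoints.ColorSpace.GRAY,
    old_color_space=datapoints.ColorSpace.RGB,   # required for plain tensors
)
```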
...@@ -415,20 +415,19 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) - ...@@ -415,20 +415,19 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
def convert_dtype( def convert_dtype(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor: ) -> torch.Tensor:
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
return convert_dtype_image_tensor(inpt, dtype) return convert_dtype_image_tensor(inpt, dtype)
elif isinstance(inpt, features.Image): elif isinstance(inpt, datapoints.Image):
output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype) output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
return features.Image.wrap_like(inpt, output) return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, features.Video): elif isinstance(inpt, datapoints.Video):
output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype) output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
return features.Video.wrap_like(inpt, output) return datapoints.Video.wrap_like(inpt, output)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, " f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
f"but got {type(inpt)} instead."
) )
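As with the stable `convert_image_dtype`, the conversion above is assumed to rescale values into the target dtype's range rather than merely casting; a usage sketch:

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img_u8 = datapoints.Image(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
img_f32 = F.convert_dtype(img_u8, torch.float32)   # values assumed mapped into [0.0, 1.0]
```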
...@@ -4,9 +4,12 @@ from typing import List, Optional, Union ...@@ -4,9 +4,12 @@ from typing import List, Optional, Union
import PIL.Image import PIL.Image
import torch import torch
from torch.nn.functional import conv2d, pad as torch_pad from torch.nn.functional import conv2d, pad as torch_pad
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from ..utils import is_simple_tensor
def normalize_image_tensor( def normalize_image_tensor(
image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
...@@ -48,17 +51,17 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in ...@@ -48,17 +51,17 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
def normalize( def normalize(
inpt: Union[features.TensorImageTypeJIT, features.TensorVideoTypeJIT], inpt: Union[datapoints.TensorImageTypeJIT, datapoints.TensorVideoTypeJIT],
mean: List[float], mean: List[float],
std: List[float], std: List[float],
inplace: bool = False, inplace: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
if features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video)): if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor) inpt = inpt.as_subclass(torch.Tensor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, " f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
...@@ -163,16 +166,18 @@ def gaussian_blur_video( ...@@ -163,16 +166,18 @@ def gaussian_blur_video(
def gaussian_blur( def gaussian_blur(
inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> features.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, features._Feature): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
) )
import torch import torch
from torchvision.prototype import features from torchvision.prototype import datapoints
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor: def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
...@@ -11,18 +11,16 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp ...@@ -11,18 +11,16 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp
def uniform_temporal_subsample( def uniform_temporal_subsample(
inpt: features.VideoTypeJIT, num_samples: int, temporal_dim: int = -4 inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
) -> features.VideoTypeJIT: ) -> datapoints.VideoTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)): if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim) return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
elif isinstance(inpt, features.Video): elif isinstance(inpt, datapoints.Video):
if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim: if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
raise ValueError("Video inputs must have temporal_dim equivalent to -4") raise ValueError("Video inputs must have temporal_dim equivalent to -4")
output = uniform_temporal_subsample_video( output = uniform_temporal_subsample_video(
inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim
) )
return features.Video.wrap_like(inpt, output) return datapoints.Video.wrap_like(inpt, output)
else: else:
raise TypeError( raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
)
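Usage sketch for the subsampling above, with the temporal dimension in the default `-4` position:

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

video = datapoints.Video(torch.rand(16, 3, 8, 8))           # (T, C, H, W), temporal_dim = -4
clip = F.uniform_temporal_subsample(video, num_samples=4)   # still a Video, now with 4 frames
```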
...@@ -3,12 +3,12 @@ from typing import Union ...@@ -3,12 +3,12 @@ from typing import Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.transforms import functional as _F from torchvision.transforms import functional as _F
@torch.jit.unused @torch.jit.unused
def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image: def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
if isinstance(inpt, np.ndarray): if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous() output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
...@@ -17,7 +17,7 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> f ...@@ -17,7 +17,7 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> f
output = inpt output = inpt
else: else:
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.") raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
return features.Image(output) return datapoints.Image(output)
to_image_pil = _F.to_pil_image to_image_pil = _F.to_pil_image
......
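For a numpy input, `to_image_tensor` above permutes HWC to CHW before wrapping; a minimal usage sketch:

```python
import numpy as np
from torchvision.prototype.transforms import functional as F

arr = np.zeros((32, 32, 3), dtype=np.uint8)   # (H, W, C) numpy array
img = F.to_image_tensor(arr)                  # datapoints.Image with shape (3, 32, 32)
```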