Unverified Commit a8007dcd authored by Philip Meier, committed by GitHub

rename features._Feature to datapoints._Datapoint (#7002)

* rename features._Feature to datapoints.Datapoint

* _Datapoint to Datapoint

* move is_simple_tensor to transforms.utils

* fix CI

* move Datapoint out of public namespace
parent c093b9c0
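
The diff below is large but mechanical. As a rough, non-authoritative sketch of the user-facing effect of this rename (based only on the imports and constructors that appear in this commit; the tensor shapes are illustrative):

import torch

# Old prototype API (before this commit):
#   from torchvision.prototype import features
#   image = features.Image(torch.rand(3, 256, 256))
#   features.is_simple_tensor(torch.rand(3, 256, 256))

# New prototype API (after this commit):
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import is_simple_tensor

image = datapoints.Image(torch.rand(3, 256, 256))  # tensor subclass, formerly features.Image
is_simple_tensor(torch.rand(3, 256, 256))          # helper moved to transforms.utils
# The Datapoint base class now lives in torchvision.prototype.datapoints._datapoint and is
# deliberately kept out of the public namespace (see the last commit message bullet).
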
......@@ -5,9 +5,9 @@ import pathlib
from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedData, EncodedImage
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Label
__all__ = ["from_data_folder", "from_image_folder"]
......
......@@ -7,13 +7,13 @@ from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torchvision.prototype.features._feature import _Feature
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
D = TypeVar("D", bound="EncodedData")
class EncodedData(_Feature):
class EncodedData(Datapoint):
@classmethod
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)
......
......@@ -6,16 +6,17 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode
from ._transform import _RandomApplyTransform
from .utils import has_any, query_chw, query_spatial_size
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
class RandomErasing(_RandomApplyTransform):
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video)
_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
def __init__(
self,
......@@ -91,8 +92,8 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)
......@@ -107,20 +108,20 @@ class _BaseMixupCutmix(_RandomApplyTransform):
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_any(flat_inputs, features.Image, features.Video, features.is_simple_tensor)
and has_any(flat_inputs, features.OneHotLabel)
has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
and has_any(flat_inputs, datapoints.OneHotLabel)
):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
if has_any(flat_inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Label):
raise TypeError(
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
)
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
def _mixup_onehotlabel(self, inpt: datapoints.OneHotLabel, lam: float) -> datapoints.OneHotLabel:
if inpt.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
return features.OneHotLabel.wrap_like(inpt, output)
return datapoints.OneHotLabel.wrap_like(inpt, output)
class RandomMixup(_BaseMixupCutmix):
......@@ -129,17 +130,17 @@ class RandomMixup(_BaseMixupCutmix):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"]
if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
expected_ndim = 5 if isinstance(inpt, features.Video) else 4
if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
if inpt.ndim < expected_ndim:
raise ValueError("The transform expects a batched input")
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
if isinstance(inpt, (features.Image, features.Video)):
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
return output
elif isinstance(inpt, features.OneHotLabel):
elif isinstance(inpt, datapoints.OneHotLabel):
return self._mixup_onehotlabel(inpt, lam)
else:
return inpt
......@@ -169,9 +170,9 @@ class RandomCutmix(_BaseMixupCutmix):
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
box = params["box"]
expected_ndim = 5 if isinstance(inpt, features.Video) else 4
expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
if inpt.ndim < expected_ndim:
raise ValueError("The transform expects a batched input")
x1, y1, x2, y2 = box
......@@ -179,11 +180,11 @@ class RandomCutmix(_BaseMixupCutmix):
output = inpt.clone()
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
if isinstance(inpt, (features.Image, features.Video)):
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
elif isinstance(inpt, features.OneHotLabel):
elif isinstance(inpt, datapoints.OneHotLabel):
lam_adjusted = params["lam_adjusted"]
return self._mixup_onehotlabel(inpt, lam_adjusted)
else:
......@@ -205,15 +206,15 @@ class SimpleCopyPaste(_RandomApplyTransform):
def _copy_paste(
self,
image: features.TensorImageType,
image: datapoints.TensorImageType,
target: Dict[str, Any],
paste_image: features.TensorImageType,
paste_image: datapoints.TensorImageType,
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool,
resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[features.TensorImageType, Dict[str, Any]]:
) -> Tuple[datapoints.TensorImageType, Dict[str, Any]]:
paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
......@@ -262,7 +263,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1
boxes = F.convert_format_bounding_box(
xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
......@@ -271,7 +272,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
# Check for degenerated boxes and remove them
boxes = F.convert_format_bounding_box(
out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
......@@ -285,20 +286,20 @@ class SimpleCopyPaste(_RandomApplyTransform):
def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]:
) -> Tuple[List[datapoints.TensorImageType], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBox], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
elif isinstance(obj, datapoints.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.Mask):
elif isinstance(obj, datapoints.Mask):
masks.append(obj)
elif isinstance(obj, (features.Label, features.OneHotLabel)):
elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
labels.append(obj)
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
......@@ -316,27 +317,27 @@ class SimpleCopyPaste(_RandomApplyTransform):
def _insert_outputs(
self,
flat_sample: List[Any],
output_images: List[features.TensorImageType],
output_images: List[datapoints.TensorImageType],
output_targets: List[Dict[str, Any]],
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, features.Image):
flat_sample[i] = features.Image.wrap_like(obj, output_images[c0])
if isinstance(obj, datapoints.Image):
flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0])
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
c0 += 1
elif features.is_simple_tensor(obj):
elif is_simple_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, features.BoundingBox):
flat_sample[i] = features.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
elif isinstance(obj, datapoints.BoundingBox):
flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
c1 += 1
elif isinstance(obj, features.Mask):
flat_sample[i] = features.Mask.wrap_like(obj, output_targets[c2]["masks"])
elif isinstance(obj, datapoints.Mask):
flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])
c2 += 1
elif isinstance(obj, (features.Label, features.OneHotLabel)):
elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
c3 += 1
......
......@@ -5,13 +5,14 @@ import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_spatial_size
from torchvision.transforms import functional_tensor as _FT
from ._utils import _setup_fill_arg
from .utils import check_type
from .utils import check_type, is_simple_tensor
class _AutoAugmentBase(Transform):
......@@ -19,7 +20,7 @@ class _AutoAugmentBase(Transform):
self,
*,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None:
super().__init__()
self.interpolation = interpolation
......@@ -33,13 +34,21 @@ class _AutoAugmentBase(Transform):
def _flatten_and_extract_image_or_video(
self,
inputs: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[features.ImageType, features.VideoType]]:
unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
image_or_videos = []
for idx, inpt in enumerate(flat_inputs):
if check_type(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
if check_type(
inpt,
(
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
),
):
image_or_videos.append((idx, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
......@@ -58,7 +67,7 @@ class _AutoAugmentBase(Transform):
def _unflatten_and_insert_image_or_video(
self,
flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
image_or_video: Union[features.ImageType, features.VideoType],
image_or_video: Union[datapoints.ImageType, datapoints.VideoType],
) -> Any:
flat_inputs, spec, idx = flat_inputs_with_spec
flat_inputs[idx] = image_or_video
......@@ -66,12 +75,12 @@ class _AutoAugmentBase(Transform):
def _apply_image_or_video_transform(
self,
image: Union[features.ImageType, features.VideoType],
image: Union[datapoints.ImageType, datapoints.VideoType],
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Dict[Type, features.FillTypeJIT],
) -> Union[features.ImageType, features.VideoType]:
fill: Dict[Type, datapoints.FillTypeJIT],
) -> Union[datapoints.ImageType, datapoints.VideoType]:
fill_ = fill[type(image)]
if transform_id == "Identity":
......@@ -182,7 +191,7 @@ class AutoAugment(_AutoAugmentBase):
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
......@@ -338,7 +347,7 @@ class RandAugment(_AutoAugmentBase):
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
......@@ -390,7 +399,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
......@@ -446,7 +455,7 @@ class AugMix(_AutoAugmentBase):
alpha: float = 1.0,
all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
......@@ -474,7 +483,7 @@ class AugMix(_AutoAugmentBase):
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image_or_video.shape)
expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
......@@ -511,7 +520,7 @@ class AugMix(_AutoAugmentBase):
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (features.Image, features.Video)):
if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix)
......
......@@ -3,11 +3,12 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform
from ._transform import _RandomApplyTransform
from .utils import query_chw
from .utils import is_simple_tensor, query_chw
class ColorJitter(Transform):
......@@ -82,7 +83,12 @@ class ColorJitter(Transform):
class RandomPhotometricDistort(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(
self,
......@@ -111,15 +117,15 @@ class RandomPhotometricDistort(Transform):
)
def _permute_channels(
self, inpt: Union[features.ImageType, features.VideoType], permutation: torch.Tensor
) -> Union[features.ImageType, features.VideoType]:
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
) -> Union[datapoints.ImageType, datapoints.VideoType]:
if isinstance(inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
output = inpt[..., permutation, :, :]
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER) # type: ignore[arg-type]
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER) # type: ignore[arg-type]
elif isinstance(inpt, PIL.Image.Image):
output = F.to_image_pil(output)
......@@ -127,8 +133,8 @@ class RandomPhotometricDistort(Transform):
return output
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
if params["brightness"]:
inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
......
......@@ -5,13 +5,13 @@ import numpy as np
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from .utils import query_chw
from .utils import is_simple_tensor, query_chw
class ToTensor(Transform):
......@@ -29,7 +29,12 @@ class ToTensor(Transform):
class Grayscale(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
deprecation_msg = (
......@@ -53,16 +58,21 @@ class Grayscale(Transform):
self.num_output_channels = num_output_channels
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type]
return output
class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, p: float = 0.1) -> None:
warnings.warn(
......@@ -84,9 +94,9 @@ class RandomGrayscale(_RandomApplyTransform):
return dict(num_input_channels=num_input_channels)
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type]
return output
......@@ -5,8 +5,9 @@ from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from torchvision.transforms.functional import _get_perspective_coeffs
......@@ -22,7 +23,7 @@ from ._utils import (
_setup_float_or_seq,
_setup_size,
)
from .utils import has_all, has_any, query_bounding_box, query_spatial_size
from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size
class RandomHorizontalFlip(_RandomApplyTransform):
......@@ -145,23 +146,23 @@ class RandomResizedCrop(Transform):
)
ImageOrVideoTypeJIT = Union[features.ImageTypeJIT, features.VideoTypeJIT]
ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
class FiveCrop(Transform):
"""
Example:
>>> class BatchMultiCrop(transforms.Transform):
... def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]):
... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], datapoints.Label]):
... images_or_videos, labels = sample
... batch_size = len(images_or_videos)
... image_or_video = images_or_videos[0]
... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
... labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
... labels = datapoints.Label.wrap_like(labels, labels.repeat(batch_size))
... return images_or_videos, labels
...
>>> image = features.Image(torch.rand(3, 256, 256))
>>> label = features.Label(0)
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> label = datapoints.Label(0)
>>> transform = transforms.Compose([transforms.FiveCrop(), BatchMultiCrop()])
>>> images, labels = transform(image, label)
>>> images.shape
......@@ -170,7 +171,12 @@ class FiveCrop(Transform):
torch.Size([5])
"""
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
......@@ -182,7 +188,7 @@ class FiveCrop(Transform):
return F.five_crop(inpt, self.size)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, features.BoundingBox, features.Mask):
if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
......@@ -191,7 +197,12 @@ class TenCrop(Transform):
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
......@@ -199,12 +210,12 @@ class TenCrop(Transform):
self.vertical_flip = vertical_flip
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, features.BoundingBox, features.Mask):
if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
......@@ -212,7 +223,7 @@ class Pad(Transform):
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
......@@ -235,7 +246,7 @@ class Pad(Transform):
class RandomZoomOut(_RandomApplyTransform):
def __init__(
self,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
......@@ -276,7 +287,7 @@ class RandomRotation(Transform):
degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
......@@ -315,7 +326,7 @@ class RandomAffine(Transform):
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
......@@ -390,7 +401,7 @@ class RandomCrop(Transform):
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
......@@ -480,7 +491,7 @@ class RandomPerspective(_RandomApplyTransform):
def __init__(
self,
distortion_scale: float = 0.5,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
......@@ -540,7 +551,7 @@ class ElasticTransform(Transform):
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
......@@ -606,9 +617,9 @@ class RandomIoUCrop(Transform):
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_all(flat_inputs, features.BoundingBox)
and has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
and has_any(flat_inputs, features.Label, features.OneHotLabel)
has_all(flat_inputs, datapoints.BoundingBox)
and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
and has_any(flat_inputs, datapoints.Label, datapoints.OneHotLabel)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
......@@ -646,7 +657,7 @@ class RandomIoUCrop(Transform):
# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_box(
bboxes.as_subclass(torch.Tensor), bboxes.format, features.BoundingBoxFormat.XYXY
bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
......@@ -671,19 +682,19 @@ class RandomIoUCrop(Transform):
is_within_crop_area = params["is_within_crop_area"]
if isinstance(inpt, (features.Label, features.OneHotLabel)):
if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel)):
return inpt.wrap_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type]
output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
if isinstance(output, features.BoundingBox):
if isinstance(output, datapoints.BoundingBox):
bboxes = output[is_within_crop_area]
bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size)
output = features.BoundingBox.wrap_like(output, bboxes)
elif isinstance(output, features.Mask):
output = datapoints.BoundingBox.wrap_like(output, bboxes)
elif isinstance(output, datapoints.Mask):
# apply is_within_crop_area if mask is one-hot encoded
masks = output[is_within_crop_area]
output = features.Mask.wrap_like(output, masks)
output = datapoints.Mask.wrap_like(output, masks)
return output
......@@ -751,7 +762,7 @@ class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
......@@ -764,13 +775,19 @@ class FixedSizeCrop(Transform):
self.padding_mode = padding_mode
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
if not has_any(
flat_inputs,
PIL.Image.Image,
datapoints.Image,
is_simple_tensor,
datapoints.Video,
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
)
if has_any(flat_inputs, features.BoundingBox) and not has_any(
flat_inputs, features.Label, features.OneHotLabel
if has_any(flat_inputs, datapoints.BoundingBox) and not has_any(
flat_inputs, datapoints.Label, datapoints.OneHotLabel
):
raise TypeError(
f"If a BoundingBox is contained in the input sample, "
......@@ -809,7 +826,7 @@ class FixedSizeCrop(Transform):
)
bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
height_and_width = F.convert_format_bounding_box(
bounding_boxes, old_format=format, new_format=features.BoundingBoxFormat.XYWH
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH
)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
......@@ -842,10 +859,10 @@ class FixedSizeCrop(Transform):
)
if params["is_valid"] is not None:
if isinstance(inpt, (features.Label, features.OneHotLabel, features.Mask)):
if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel, datapoints.Mask)):
inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, features.BoundingBox):
inpt = features.BoundingBox.wrap_like(
elif isinstance(inpt, datapoints.BoundingBox):
inpt = datapoints.BoundingBox.wrap_like(
inpt,
F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size),
)
......
......@@ -3,38 +3,41 @@ from typing import Any, Dict, Optional, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform
from .utils import is_simple_tensor
class ConvertBoundingBoxFormat(Transform):
_transformed_types = (features.BoundingBox,)
_transformed_types = (datapoints.BoundingBox,)
def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
super().__init__()
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
format = datapoints.BoundingBoxFormat[format]
self.format = format
def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
# since `convert_format_bounding_box` does not have a dispatcher function that would do that for us
output = F.convert_format_bounding_box(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"]
)
return features.BoundingBox.wrap_like(inpt, output, format=params["format"])
return datapoints.BoundingBox.wrap_like(inpt, output, format=params["format"])
class ConvertDtype(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> Union[features.TensorImageType, features.TensorVideoType]:
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> Union[datapoints.TensorImageType, datapoints.TensorVideoType]:
return F.convert_dtype(inpt, self.dtype)
......@@ -44,36 +47,41 @@ ConvertImageDtype = ConvertDtype
class ConvertColorSpace(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video)
_transformed_types = (
is_simple_tensor,
datapoints.Image,
PIL.Image.Image,
datapoints.Video,
)
def __init__(
self,
color_space: Union[str, features.ColorSpace],
old_color_space: Optional[Union[str, features.ColorSpace]] = None,
color_space: Union[str, datapoints.ColorSpace],
old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None,
) -> None:
super().__init__()
if isinstance(color_space, str):
color_space = features.ColorSpace.from_str(color_space)
color_space = datapoints.ColorSpace.from_str(color_space)
self.color_space = color_space
if isinstance(old_color_space, str):
old_color_space = features.ColorSpace.from_str(old_color_space)
old_color_space = datapoints.ColorSpace.from_str(old_color_space)
self.old_color_space = old_color_space
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)
class ClampBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox,)
_transformed_types = (datapoints.BoundingBox,)
def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
# since `clamp_bounding_box` does not have a dispatcher function that would do that for us
output = F.clamp_bounding_box(
inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
)
return features.BoundingBox.wrap_like(inpt, output)
return datapoints.BoundingBox.wrap_like(inpt, output)
......@@ -3,12 +3,13 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, U
import PIL.Image
import torch
from torchvision.ops import remove_small_boxes
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
from .utils import has_any, query_bounding_box
from .utils import has_any, is_simple_tensor, query_bounding_box
class Identity(Transform):
......@@ -38,7 +39,7 @@ class Lambda(Transform):
class LinearTransformation(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
......@@ -67,7 +68,7 @@ class LinearTransformation(Transform):
raise TypeError("LinearTransformation does not work on PIL Images")
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
# Image instance after linear transformation is not Image anymore due to unknown data range
# Thus we will return Tensor for input Image
......@@ -93,7 +94,7 @@ class LinearTransformation(Transform):
class Normalize(Transform):
_transformed_types = (features.Image, features.is_simple_tensor, features.Video)
_transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
......@@ -106,7 +107,7 @@ class Normalize(Transform):
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
......@@ -158,7 +159,7 @@ class ToDtype(Transform):
class PermuteDimensions(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
super().__init__()
......@@ -167,7 +168,7 @@ class PermuteDimensions(Transform):
self.dims = dims
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
......@@ -176,7 +177,7 @@ class PermuteDimensions(Transform):
class TransposeDimensions(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
super().__init__()
......@@ -185,7 +186,7 @@ class TransposeDimensions(Transform):
self.dims = dims
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
......@@ -194,7 +195,7 @@ class TransposeDimensions(Transform):
class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
_transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)
def __init__(self, min_size: float = 1.0) -> None:
super().__init__()
......@@ -210,7 +211,7 @@ class RemoveSmallBoundingBoxes(Transform):
bounding_box = F.convert_format_bounding_box(
bounding_box.as_subclass(torch.Tensor),
old_format=bounding_box.format,
new_format=features.BoundingBoxFormat.XYXY,
new_format=datapoints.BoundingBoxFormat.XYXY,
)
valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
......
from typing import Any, Dict
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms.utils import is_simple_tensor
class UniformTemporalSubsample(Transform):
_transformed_types = (features.is_simple_tensor, features.Video)
_transformed_types = (is_simple_tensor, datapoints.Video)
def __init__(self, num_samples: int, temporal_dim: int = -4):
super().__init__()
self.num_samples = num_samples
self.temporal_dim = temporal_dim
def _transform(self, inpt: features.VideoType, params: Dict[str, Any]) -> features.VideoType:
def _transform(self, inpt: datapoints.VideoType, params: Dict[str, Any]) -> datapoints.VideoType:
return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim)
......@@ -5,23 +5,26 @@ import PIL.Image
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms.utils import is_simple_tensor
class LabelToOneHot(Transform):
_transformed_types = (features.Label,)
_transformed_types = (datapoints.Label,)
def __init__(self, num_categories: int = -1):
super().__init__()
self.num_categories = num_categories
def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.OneHotLabel:
def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel:
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return features.OneHotLabel(output, categories=inpt.categories)
return datapoints.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str:
if self.num_categories == -1:
......@@ -38,16 +41,16 @@ class PILToTensor(Transform):
class ToImageTensor(Transform):
_transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray)
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> features.Image:
) -> datapoints.Image:
return F.to_image_tensor(inpt) # type: ignore[no-any-return]
class ToImagePIL(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)
_transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
......
......@@ -3,8 +3,8 @@ import numbers
from collections import defaultdict
from typing import Any, Dict, Sequence, Type, TypeVar, Union
from torchvision.prototype import features
from torchvision.prototype.features._feature import FillType, FillTypeJIT
from torchvision.prototype import datapoints
from torchvision.prototype.datapoints._datapoint import FillType, FillTypeJIT
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
......@@ -54,7 +54,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
return defaultdict(functools.partial(_default_arg, default))
def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# if fill is None:
......
......@@ -3,7 +3,7 @@ from typing import Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
......@@ -33,28 +33,28 @@ def erase_video(
def erase(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, features.Image):
elif isinstance(inpt, datapoints.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Image.wrap_like(inpt, output)
elif isinstance(inpt, features.Video):
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Video.wrap_like(inpt, output)
return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
import PIL.Image
import torch
from torch.nn.functional import conv2d
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value
......@@ -37,16 +37,18 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -76,16 +78,18 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)
def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -115,16 +119,18 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)
def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -188,16 +194,18 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)
def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -300,16 +308,18 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image_tensor(video, hue_factor=hue_factor)
def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_hue(hue_factor=hue_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -340,16 +350,18 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)
def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -374,16 +386,18 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image_tensor(video, bits=bits)
def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.posterize(bits=bits)
elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -402,16 +416,18 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image_tensor(video, threshold=threshold)
def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.solarize(threshold=threshold)
elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -452,16 +468,18 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image_tensor(video)
def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.autocontrast()
elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -542,16 +560,18 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image_tensor(video)
def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return equalize_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.equalize()
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -573,15 +593,17 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video)
def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return invert_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.invert()
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -4,16 +4,16 @@ from typing import Any, List, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.transforms import functional as _F
@torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = "convert_color_space(..., color_space=features.ColorSpace.GRAY)"
replacement = "convert_color_space(..., color_space=datapoints.ColorSpace.GRAY)"
if num_output_channels == 3:
replacement = f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB)"
replacement = f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB)"
warnings.warn(
f"The function `to_grayscale(...{call})` is deprecated in will be removed in a future release. "
f"Instead, please use `{replacement}`.",
......@@ -23,25 +23,25 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], num_output_channels: int = 1
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting() and isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)
old_color_space = None
elif isinstance(inpt, torch.Tensor):
old_color_space = features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
else:
old_color_space = None
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
f"convert_color_space(..., color_space=features.ColorSpace.GRAY"
f"{f', old_color_space=features.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
f"{f', old_color_space=datapoints.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
)
if num_output_channels == 3:
replacement = (
f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB"
f"{f', old_color_space=features.ColorSpace.GRAY' if old_color_space is not None else ''})"
f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB"
f"{f', old_color_space=datapoints.ColorSpace.GRAY' if old_color_space is not None else ''})"
)
warnings.warn(
f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. "
......@@ -60,7 +60,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
return _F.to_tensor(inpt)
def get_image_size(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
def get_image_size(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
......
......@@ -2,8 +2,8 @@ from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.prototype import datapoints
from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value
......@@ -23,12 +23,12 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
get_dimensions_image_pil = _FP.get_dimensions
def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return get_dimensions_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video)):
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
channels = inpt.num_channels
height, width = inpt.spatial_size
return [channels, height, width]
......@@ -36,7 +36,7 @@ def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) ->
return get_dimensions_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
......@@ -59,18 +59,18 @@ def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image_tensor(video)
def get_num_channels(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int:
def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> int:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return get_num_channels_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video)):
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.num_channels
elif isinstance(inpt, PIL.Image.Image):
return get_num_channels_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
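Illustrative usage, not part of this diff: a sketch of the two metadata getters above, which read channel and spatial information from the datapoint instead of re-deriving it from the raw tensor shape. The datapoints.Video constructor call is an assumption for illustration; only the getters appear in this hunk.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

video = datapoints.Video(torch.rand(8, 3, 240, 320))  # assumed T, C, H, W layout

F.get_dimensions(video)    # [3, 240, 320], from num_channels and spatial_size
F.get_num_channels(video)  # 3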
......@@ -104,20 +104,22 @@ def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
@torch.jit.unused
def get_spatial_size_bounding_box(bounding_box: features.BoundingBox) -> List[int]:
def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]:
return list(bounding_box.spatial_size)
def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return get_spatial_size_image_tensor(inpt)
elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)):
elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
return list(inpt.spatial_size)
elif isinstance(inpt, PIL.Image.Image):
return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return]
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
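Illustrative usage, not part of this diff: get_spatial_size also accepts BoundingBox and Mask datapoints, since they carry spatial_size metadata. The Mask constructor below is an assumption for illustration.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img = datapoints.Image(torch.rand(3, 480, 640))
mask = datapoints.Mask(torch.zeros(480, 640, dtype=torch.bool))  # assumed constructor

F.get_spatial_size(img)   # [480, 640]
F.get_spatial_size(mask)  # [480, 640], read from the datapoint's spatial_size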
......@@ -126,15 +128,13 @@ def get_num_frames_video(video: torch.Tensor) -> int:
return video.shape[-4]
def get_num_frames(inpt: features.VideoTypeJIT) -> int:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)):
def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
return get_num_frames_video(inpt)
elif isinstance(inpt, features.Video):
elif isinstance(inpt, datapoints.Video):
return inpt.num_frames
else:
raise TypeError(
f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
)
raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
......@@ -202,7 +202,7 @@ def clamp_bounding_box(
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
xyxy_boxes = convert_format_bounding_box(
bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
bounding_box.clone(), old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
)
xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
......@@ -309,12 +309,12 @@ def convert_color_space_video(
def convert_color_space(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
if old_color_space is None:
raise RuntimeError(
......@@ -322,21 +322,21 @@ def convert_color_space(
"the `old_color_space=...` parameter needs to be passed."
)
return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
elif isinstance(inpt, features.Image):
elif isinstance(inpt, datapoints.Image):
output = convert_color_space_image_tensor(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
)
return features.Image.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, features.Video):
return datapoints.Image.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, datapoints.Video):
output = convert_color_space_video(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
)
return features.Video.wrap_like(inpt, output, color_space=color_space)
return datapoints.Video.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, PIL.Image.Image):
return convert_color_space_image_pil(inpt, color_space=color_space)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
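Illustrative usage, not part of this diff: plain tensors carry no color-space metadata, so old_color_space must be passed explicitly, whereas Image and Video datapoints supply it themselves and the result is re-wrapped via wrap_like. The color_space keyword on the Image constructor is an assumption for illustration.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

rgb = torch.rand(3, 64, 64)

# Plain tensor path: the old color space has to be stated.
F.convert_color_space(rgb, color_space=datapoints.ColorSpace.GRAY, old_color_space=datapoints.ColorSpace.RGB)

# Datapoint path: the color space is taken from the Image itself.
img = datapoints.Image(rgb, color_space=datapoints.ColorSpace.RGB)  # assumed constructor kwarg
F.convert_color_space(img, color_space=datapoints.ColorSpace.GRAY)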
......@@ -415,20 +415,19 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
def convert_dtype(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return convert_dtype_image_tensor(inpt, dtype)
elif isinstance(inpt, features.Image):
elif isinstance(inpt, datapoints.Image):
output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
return features.Image.wrap_like(inpt, output)
elif isinstance(inpt, features.Video):
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
return features.Video.wrap_like(inpt, output)
return datapoints.Video.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, "
f"but got {type(inpt)} instead."
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
)
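Illustrative usage, not part of this diff: a sketch of convert_dtype on an Image datapoint; the dtype conversion rescales the value range (uint8 in [0, 255] to float in [0, 1]) and the output is re-wrapped with Image.wrap_like.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img_u8 = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))

out = F.convert_dtype(img_u8, torch.float32)  # float Image with values in [0, 1]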
......@@ -4,9 +4,12 @@ from typing import List, Optional, Union
import PIL.Image
import torch
from torch.nn.functional import conv2d, pad as torch_pad
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from ..utils import is_simple_tensor
def normalize_image_tensor(
image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
......@@ -48,17 +51,17 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
def normalize(
inpt: Union[features.TensorImageTypeJIT, features.TensorVideoTypeJIT],
inpt: Union[datapoints.TensorImageTypeJIT, datapoints.TensorVideoTypeJIT],
mean: List[float],
std: List[float],
inplace: bool = False,
) -> torch.Tensor:
if not torch.jit.is_scripting():
if features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video)):
if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, "
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f"but got {type(inpt)} instead."
)
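Illustrative usage, not part of this diff: normalize unwraps Image and Video datapoints to plain tensors before normalizing, so the result is a torch.Tensor rather than a datapoint. The mean/std values below are just illustrative numbers.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img = datapoints.Image(torch.rand(3, 224, 224))

out = F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# out is a plain torch.Tensor, not a datapoints.Image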
......@@ -163,16 +166,18 @@ def gaussian_blur_video(
def gaussian_blur(
inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, features._Feature):
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, PIL.Image.Image):
return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
else:
raise TypeError(
f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
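Illustrative usage, not part of this diff: a sketch of the gaussian_blur dispatcher above for the tensor and datapoint paths.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img = datapoints.Image(torch.rand(3, 64, 64))

F.gaussian_blur(img, kernel_size=[5, 5], sigma=[1.0, 1.0])                            # Datapoint -> inpt.gaussian_blur(...)
F.gaussian_blur(img.as_subclass(torch.Tensor), kernel_size=[5, 5], sigma=[1.0, 1.0])  # plain tensor path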
import torch
from torchvision.prototype import features
from torchvision.prototype import datapoints
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
......@@ -11,18 +11,16 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp
def uniform_temporal_subsample(
inpt: features.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
) -> features.VideoTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)):
inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
) -> datapoints.VideoTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
elif isinstance(inpt, features.Video):
elif isinstance(inpt, datapoints.Video):
if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
raise ValueError("Video inputs must have temporal_dim equivalent to -4")
output = uniform_temporal_subsample_video(
inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim
)
return features.Video.wrap_like(inpt, output)
return datapoints.Video.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
)
raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
......@@ -3,12 +3,12 @@ from typing import Union
import numpy as np
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype import datapoints
from torchvision.transforms import functional as _F
@torch.jit.unused
def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
elif isinstance(inpt, PIL.Image.Image):
......@@ -17,7 +17,7 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> f
output = inpt
else:
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
return features.Image(output)
return datapoints.Image(output)
to_image_pil = _F.to_pil_image
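Illustrative usage, not part of this diff: to_image_tensor wraps PIL images and numpy arrays as Image datapoints, permuting numpy HWC input to CHW as shown in the hunk above; it is assumed here to be exported from the prototype functional namespace as F.to_image_tensor.

import numpy as np
import PIL.Image
from torchvision.prototype.transforms import functional as F

F.to_image_tensor(PIL.Image.new("RGB", (32, 32)))         # -> datapoints.Image
F.to_image_tensor(np.zeros((32, 32, 3), dtype=np.uint8))  # HWC array permuted to CHW -> datapoints.Image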
......