Unverified Commit 3991ab99 authored by Philip Meier, committed by GitHub

Promote prototype transforms to beta status (#7261)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent d010e82f
......@@ -4,7 +4,8 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tupl
import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.datapoints import BoundingBox
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......
......@@ -6,7 +6,8 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, U
import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Image, Label
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......
......@@ -14,7 +14,8 @@ from torchdata.datapipes.iter import (
Mapper,
UnBatcher,
)
from torchvision.prototype.datapoints import BoundingBox, Label, Mask
from torchvision.datapoints import BoundingBox, Mask
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......
......@@ -15,7 +15,8 @@ from torchdata.datapipes.iter import (
Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.datapoints import BoundingBox
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......
......@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Image, Label
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
......
......@@ -2,7 +2,8 @@ import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.datapoints import BoundingBox
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......
......@@ -7,7 +7,8 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence
import torch
from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datapoints import Image, Label
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE
from torchvision.prototype.utils._internal import fromfile
......
......@@ -4,7 +4,8 @@ from collections import namedtuple
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
from torchvision.prototype.datapoints import Image, Label
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
......
......@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Tuple, Union
import torch
from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Image, OneHotLabel
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import OneHotLabel
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
......
......@@ -2,7 +2,8 @@ import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.datapoints import BoundingBox
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......
......@@ -3,7 +3,8 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np
from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher
from torchvision.prototype.datapoints import Image, Label
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat
......
......@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper
from torchvision.prototype.datapoints import Image, Label
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
......
......@@ -5,8 +5,9 @@ from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
from xml.etree import ElementTree
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.datapoints import BoundingBox
from torchvision.datasets import VOCDetection
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......
......@@ -7,7 +7,7 @@ from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.datapoints._datapoint import Datapoint
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
D = TypeVar("D", bound="EncodedData")
......
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
from . import functional, utils # usort: skip
from ._transform import Transform # usort: skip
from ._presets import StereoMatching # usort: skip
from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Grayscale,
RandomAdjustSharpness,
RandomAutocontrast,
RandomEqualize,
RandomGrayscale,
RandomInvert,
RandomPhotometricDistort,
RandomPosterize,
RandomSolarize,
)
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
CenterCrop,
ElasticTransform,
FiveCrop,
FixedSizeCrop,
Pad,
RandomAffine,
RandomCrop,
RandomHorizontalFlip,
RandomIoUCrop,
RandomPerspective,
RandomResize,
RandomResizedCrop,
RandomRotation,
RandomShortestSize,
RandomVerticalFlip,
RandomZoomOut,
Resize,
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
from ._misc import (
GaussianBlur,
Identity,
Lambda,
LinearTransformation,
Normalize,
PermuteDimensions,
SanitizeBoundingBoxes,
ToDtype,
TransposeDimensions,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._deprecated import ToTensor # usort: skip
from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste
from ._geometry import FixedSizeCrop
from ._misc import PermuteDimensions, TransposeDimensions
from ._type_conversion import LabelToOneHot
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms
from torchvision import datapoints
from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from torchvision.prototype import datapoints as proto_datapoints
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from ._transform import _RandomApplyTransform
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
class RandomErasing(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomErasing
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return dict(
super()._extract_params_for_v1_transform(),
value="random" if self.value is None else self.value,
)
_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
def __init__(
self,
p: float = 0.5,
scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0.0,
inplace: bool = False,
):
super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
self.scale = scale
self.ratio = ratio
if isinstance(value, (int, float)):
self.value = [float(value)]
elif isinstance(value, str):
self.value = None
elif isinstance(value, (list, tuple)):
self.value = [float(v) for v in value]
else:
self.value = value
self.inplace = inplace
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)
if self.value is not None and not (len(self.value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
)
area = img_h * img_w
log_ratio = self._log_ratio
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if self.value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(self.value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, None
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)
return inpt
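# Usage sketch (illustration only; assumes torchvision >= 0.15, where the promoted transforms
# are importable from torchvision.transforms.v2 and tensor subclasses from torchvision.datapoints):
import torch
from torchvision import datapoints
from torchvision.transforms import v2

eraser = v2.RandomErasing(p=1.0, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0.0)
img = datapoints.Image(torch.rand(3, 224, 224))
out = eraser(img)  # one rectangular region, sampled as in _get_params above, is filled with 0.0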
from torchvision.transforms.v2._transform import _RandomApplyTransform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_spatial_size
class _BaseMixupCutmix(_RandomApplyTransform):
......@@ -118,19 +23,19 @@ class _BaseMixupCutmix(_RandomApplyTransform):
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
and has_any(flat_inputs, datapoints.OneHotLabel)
and has_any(flat_inputs, proto_datapoints.OneHotLabel)
):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Label):
if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, proto_datapoints.Label):
raise TypeError(
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
)
def _mixup_onehotlabel(self, inpt: datapoints.OneHotLabel, lam: float) -> datapoints.OneHotLabel:
def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel:
if inpt.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
return datapoints.OneHotLabel.wrap_like(inpt, output)
return proto_datapoints.OneHotLabel.wrap_like(inpt, output)
class RandomMixup(_BaseMixupCutmix):
......@@ -149,7 +54,7 @@ class RandomMixup(_BaseMixupCutmix):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
return output
elif isinstance(inpt, datapoints.OneHotLabel):
elif isinstance(inpt, proto_datapoints.OneHotLabel):
return self._mixup_onehotlabel(inpt, lam)
else:
return inpt
......@@ -193,7 +98,7 @@ class RandomCutmix(_BaseMixupCutmix):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
elif isinstance(inpt, datapoints.OneHotLabel):
elif isinstance(inpt, proto_datapoints.OneHotLabel):
lam_adjusted = params["lam_adjusted"]
return self._mixup_onehotlabel(inpt, lam_adjusted)
else:
......@@ -307,7 +212,7 @@ class SimpleCopyPaste(Transform):
bboxes.append(obj)
elif isinstance(obj, datapoints.Mask):
masks.append(obj)
elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
labels.append(obj)
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
......@@ -345,7 +250,7 @@ class SimpleCopyPaste(Transform):
elif isinstance(obj, datapoints.Mask):
flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])
c2 += 1
elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
c3 += 1
......
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Sequence, Type, Union
import PIL.Image
import torch
from torchvision import transforms as _transforms
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from torchvision.transforms.functional import _get_perspective_coeffs
from ._transform import _RandomApplyTransform
from ._utils import (
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_setup_angle,
_setup_fill_arg,
_setup_float_or_seq,
_setup_size,
)
from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size
class RandomHorizontalFlip(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomHorizontalFlip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.horizontal_flip(inpt)
class RandomVerticalFlip(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomVerticalFlip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.vertical_flip(inpt)
class Resize(Transform):
_v1_transform_cls = _transforms.Resize
def __init__(
self,
size: Union[int, Sequence[int]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
if isinstance(size, int):
size = [size]
elif isinstance(size, (list, tuple)) and len(size) in {1, 2}:
size = list(size)
else:
raise ValueError(
f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead."
)
self.size = size
self.interpolation = _check_interpolation(interpolation)
self.max_size = max_size
self.antialias = antialias
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(
inpt,
self.size,
interpolation=self.interpolation,
max_size=self.max_size,
antialias=self.antialias,
)
class CenterCrop(Transform):
_v1_transform_cls = _transforms.CenterCrop
def __init__(self, size: Union[int, Sequence[int]]):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.center_crop(inpt, output_size=self.size)
class RandomResizedCrop(Transform):
_v1_transform_cls = _transforms.RandomResizedCrop
def __init__(
self,
size: Union[int, Sequence[int]],
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
scale = cast(Tuple[float, float], scale)
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
ratio = cast(Tuple[float, float], ratio)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
self.scale = scale
self.ratio = ratio
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
area = height * width
log_ratio = self._log_ratio
for _ in range(10):
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
break
else:
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.ratio):
w = width
h = int(round(w / min(self.ratio)))
elif in_ratio > max(self.ratio):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return dict(top=i, left=j, height=h, width=w)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resized_crop(
inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
)
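# Usage sketch (illustration only; assumes torchvision >= 0.15 with transforms.v2 and datapoints):
import torch
from torchvision import datapoints
from torchvision.transforms import v2

crop = v2.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), antialias=True)
out = crop(datapoints.Image(torch.rand(3, 480, 640)))  # random-area crop, resized to 224x224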
ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
class FiveCrop(Transform):
"""
Example:
>>> class BatchMultiCrop(transforms.Transform):
... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], datapoints.Label]):
... images_or_videos, labels = sample
... batch_size = len(images_or_videos)
... image_or_video = images_or_videos[0]
... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
... labels = datapoints.Label.wrap_like(labels, labels.repeat(batch_size))
... return images_or_videos, labels
...
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> label = datapoints.Label(0)
>>> transform = transforms.Compose([transforms.FiveCrop(), BatchMultiCrop()])
>>> images, labels = transform(image, label)
>>> images.shape
torch.Size([5, 3, 224, 224])
>>> labels.shape
torch.Size([5])
"""
_v1_transform_cls = _transforms.FiveCrop
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(
self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
return F.five_crop(inpt, self.size)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
class TenCrop(Transform):
"""
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""
_v1_transform_cls = _transforms.TenCrop
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Tuple[
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
class Pad(Transform):
_v1_transform_cls = _transforms.Pad
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
return params
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
self.padding = padding
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
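# Usage sketch (illustration only; assumes torchvision >= 0.15): per the signature above,
# `fill` may be a single value or a mapping from input type to fill value.
import torch
from torchvision import datapoints
from torchvision.transforms import v2

pad = v2.Pad(padding=4, fill={datapoints.Image: 1.0, datapoints.Mask: 0})
img = datapoints.Image(torch.rand(3, 32, 32))
out = pad(img)  # image borders padded with 1.0; masks (if present) would be padded with 0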
class RandomZoomOut(_RandomApplyTransform):
def __init__(
self,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
super().__init__(p=p)
self.fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.pad(inpt, **params, fill=fill)
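# Usage sketch (illustration only; assumes torchvision >= 0.15): the input is placed on a larger
# canvas whose side is 1x-4x the original, with the extra area filled by `fill`.
import torch
from torchvision import datapoints
from torchvision.transforms import v2

zoom_out = v2.RandomZoomOut(fill=0, side_range=(1.0, 4.0), p=1.0)
out = zoom_out(datapoints.Image(torch.rand(3, 100, 100)))  # padded via F.pad with the sampled canvas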
class RandomRotation(Transform):
_v1_transform_cls = _transforms.RandomRotation
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = _check_interpolation(interpolation)
self.expand = expand
self.fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.rotate(
inpt,
**params,
interpolation=self.interpolation,
expand=self.expand,
center=self.center,
fill=fill,
)
class RandomAffine(Transform):
_v1_transform_cls = _transforms.RandomAffine
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
translate: Optional[Sequence[float]] = None,
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2,))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
if self.translate is not None:
max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height)
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
translate = (tx, ty)
else:
translate = (0, 0)
if self.scale is not None:
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
else:
scale = 1.0
shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
if len(self.shear) == 4:
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
shear = (shear_x, shear_y)
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.affine(
inpt,
**params,
interpolation=self.interpolation,
fill=fill,
center=self.center,
)
class RandomCrop(Transform):
_v1_transform_cls = _transforms.RandomCrop
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
padding = self.padding
if padding is not None:
pad_left, pad_right, pad_top, pad_bottom = padding
padding = [pad_left, pad_top, pad_right, pad_bottom]
params["padding"] = padding
return params
def __init__(
self,
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if pad_if_needed or padding is not None:
if padding is not None:
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
self.pad_if_needed = pad_if_needed
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
padded_height, padded_width = query_spatial_size(flat_inputs)
if self.padding is not None:
pad_left, pad_right, pad_top, pad_bottom = self.padding
padded_height += pad_top + pad_bottom
padded_width += pad_left + pad_right
else:
pad_left = pad_right = pad_top = pad_bottom = 0
cropped_height, cropped_width = self.size
if self.pad_if_needed:
if padded_height < cropped_height:
diff = cropped_height - padded_height
pad_top += diff
pad_bottom += diff
padded_height += 2 * diff
if padded_width < cropped_width:
diff = cropped_width - padded_width
pad_left += diff
pad_right += diff
padded_width += 2 * diff
if padded_height < cropped_height or padded_width < cropped_width:
raise ValueError(
f"Required crop size {(cropped_height, cropped_width)} is larger than "
f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
)
# We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
padding = [pad_left, pad_top, pad_right, pad_bottom]
needs_pad = any(padding)
needs_vert_crop, top = (
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
if padded_height > cropped_height
else (False, 0)
)
needs_horz_crop, left = (
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
if padded_width > cropped_width
else (False, 0)
)
return dict(
needs_crop=needs_vert_crop or needs_horz_crop,
top=top,
left=left,
height=cropped_height,
width=cropped_width,
needs_pad=needs_pad,
padding=padding,
)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self.fill[type(inpt)]
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]:
inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
return inpt
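# Usage sketch (illustration only; assumes torchvision >= 0.15): with pad_if_needed=True the
# input is first padded whenever it is smaller than the requested crop size.
import torch
from torchvision import datapoints
from torchvision.transforms import v2

crop = v2.RandomCrop(size=(224, 224), pad_if_needed=True, fill=0)
out = crop(datapoints.Image(torch.rand(3, 200, 180)))  # -> 3 x 224 x 224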
class RandomPerspective(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomPerspective
def __init__(
self,
distortion_scale: float = 0.5,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
super().__init__(p=p)
if not (0 <= distortion_scale <= 1):
raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
distortion_scale = self.distortion_scale
half_height = height // 2
half_width = width // 2
bound_height = int(distortion_scale * half_height) + 1
bound_width = int(distortion_scale * half_width) + 1
topleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
topright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
botright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
botleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
return dict(coefficients=perspective_coeffs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.perspective(
inpt,
None,
None,
fill=fill,
interpolation=self.interpolation,
**params,
)
class ElasticTransform(Transform):
_v1_transform_cls = _transforms.ElasticTransform
def __init__(
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = list(query_spatial_size(flat_inputs))
dx = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[0] > 0.0:
kx = int(8 * self.sigma[0] + 1)
# if kernel size is even we have to make it odd
if kx % 2 == 0:
kx += 1
dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / size[0]
dy = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[1] > 0.0:
ky = int(8 * self.sigma[1] + 1)
# if kernel size is even we have to make it odd
if ky % 2 == 0:
ky += 1
dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
dy = dy * self.alpha[1] / size[1]
displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.elastic(
inpt,
**params,
fill=fill,
interpolation=self.interpolation,
)
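# Usage sketch (illustration only; assumes torchvision >= 0.15): the displacement field built in
# _get_params has shape (1, H, W, 2) and is applied consistently to all datapoints in the sample.
import torch
from torchvision import datapoints
from torchvision.transforms import v2

elastic = v2.ElasticTransform(alpha=50.0, sigma=5.0)
out = elastic(datapoints.Image(torch.rand(3, 128, 128)))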
class RandomIoUCrop(Transform):
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
self.trials = trials
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_all(flat_inputs, datapoints.BoundingBox)
and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
and has_any(flat_inputs, datapoints.Label, datapoints.OneHotLabel)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
bboxes = query_bounding_box(flat_inputs)
while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
return dict()
for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue
# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue
# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_box(
bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue
# check at least 1 box with jaccard limitations
xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
ious = box_iou(
xyxy_bboxes,
torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
)
if ious.max() < min_jaccard_overlap:
continue
return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if len(params) < 1:
return inpt
is_within_crop_area = params["is_within_crop_area"]
if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel)):
return inpt.wrap_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type]
output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
if isinstance(output, datapoints.BoundingBox):
bboxes = output[is_within_crop_area]
bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size)
output = datapoints.BoundingBox.wrap_like(output, bboxes)
elif isinstance(output, datapoints.Mask):
# apply is_within_crop_area if mask is one-hot encoded
masks = output[is_within_crop_area]
output = datapoints.Mask.wrap_like(output, masks)
return output
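# Usage sketch (illustration only; assumes torchvision >= 0.15). Per _check_inputs/_get_params
# above, the sample must contain bounding boxes plus an image (tensor or PIL); each call picks
# one entry of `sampler_options` as the minimum required IoU between the crop and at least one
# box whose center lies inside it, and a value >= 1.0 means "leave the sample unchanged".
from torchvision.transforms import v2

iou_crop = v2.RandomIoUCrop(min_scale=0.3, max_scale=1.0, sampler_options=[0.0, 0.3, 0.5, 1.0])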
class ScaleJitter(Transform):
def __init__(
self,
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class RandomShortestSize(Transform):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: Optional[int] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min_size / min(orig_height, orig_width)
if self.max_size is not None:
r = min(r, self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_box, query_spatial_size
class FixedSizeCrop(Transform):
......@@ -854,9 +38,7 @@ class FixedSizeCrop(Transform):
f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
)
if has_any(flat_inputs, datapoints.BoundingBox) and not has_any(
flat_inputs, datapoints.Label, datapoints.OneHotLabel
):
if has_any(flat_inputs, datapoints.BoundingBox) and not has_any(flat_inputs, Label, OneHotLabel):
raise TypeError(
f"If a BoundingBox is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
......@@ -927,7 +109,7 @@ class FixedSizeCrop(Transform):
)
if params["is_valid"] is not None:
if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel, datapoints.Mask)):
if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)):
inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, datapoints.BoundingBox):
inpt = datapoints.BoundingBox.wrap_like(
......@@ -940,25 +122,3 @@ class FixedSizeCrop(Transform):
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
class RandomResize(Transform):
def __init__(
self,
min_size: int,
max_size: int,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
self.min_size = min_size
self.max_size = max_size
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = int(torch.randint(self.min_size, self.max_size, ()))
return dict(size=[size])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
import collections
import warnings
from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
from .utils import has_any, is_simple_tensor, query_bounding_box
class Identity(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt
class Lambda(Transform):
def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__()
self.lambd = lambd
self.types = types or (object,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types):
return self.lambd(inpt)
else:
return inpt
def extra_repr(self) -> str:
extras = []
name = getattr(self.lambd, "__name__", None)
if name:
extras.append(name)
extras.append(f"types={[type.__name__ for type in self.types]}")
return ", ".join(extras)
class LinearTransformation(Transform):
_v1_transform_cls = _transforms.LinearTransformation
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError(
"transformation_matrix should be square. Got "
f"{tuple(transformation_matrix.size())} rectangular matrix."
)
if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError(
f"mean_vector should have the same length {mean_vector.size(0)}"
f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
)
if transformation_matrix.device != mean_vector.device:
raise ValueError(
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
)
if transformation_matrix.dtype != mean_vector.dtype:
raise ValueError(
f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
)
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
def _check_inputs(self, sample: Any) -> Any:
if has_any(sample, PIL.Image.Image):
raise TypeError("LinearTransformation does not work on PIL Images")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
shape = inpt.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
raise ValueError(
"Input tensor and transformation matrix have incompatible shape."
+ f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
+ f"{self.transformation_matrix.shape[0]}"
)
if inpt.device.type != self.mean_vector.device.type:
raise ValueError(
"Input tensor should be on the same device as transformation matrix and mean vector. "
f"Got {inpt.device} vs {self.mean_vector.device}"
)
flat_inpt = inpt.reshape(-1, n) - self.mean_vector
transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype)
output = torch.mm(flat_inpt, transformation_matrix)
output = output.reshape(shape)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
return output
class Normalize(Transform):
_v1_transform_cls = _transforms.Normalize
_transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
self.mean = list(mean)
self.std = list(std)
self.inplace = inplace
def _check_inputs(self, sample: Any) -> Any:
if has_any(sample, PIL.Image.Image):
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
def _transform(
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
class GaussianBlur(Transform):
_v1_transform_cls = _transforms.GaussianBlur
def __init__(
self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
) -> None:
super().__init__()
self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
for ks in self.kernel_size:
if ks <= 0 or ks % 2 == 0:
raise ValueError("Kernel size value should be an odd and positive number.")
if isinstance(sigma, (int, float)):
if sigma <= 0:
raise ValueError("If sigma is a single number, it must be positive.")
sigma = float(sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
return dict(sigma=[sigma, sigma])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.gaussian_blur(inpt, self.kernel_size, **params)
class ToDtype(Transform):
_transformed_types = (torch.Tensor,)
from torchvision import datapoints
from torchvision.transforms.v2 import Transform
def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
dtype = _get_defaultdict(dtype)
if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
warnings.warn(
"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
)
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
dtype = self.dtype[type(inpt)]
if dtype is None:
return inpt
return inpt.to(dtype=dtype)
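# Usage sketch (illustration only; assumes torchvision >= 0.15): `dtype` may be a single dtype
# or a per-type mapping, in which case only the listed types are converted.
import torch
from torchvision import datapoints
from torchvision.transforms import v2

to_float = v2.ToDtype({datapoints.Image: torch.float32, datapoints.Video: torch.float32})
img = datapoints.Image(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
out = to_float(img)  # uint8 image converted to float32; values are not rescaled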
from torchvision.transforms.v2._utils import _get_defaultdict
from torchvision.transforms.v2.utils import is_simple_tensor
class PermuteDimensions(Transform):
......@@ -225,115 +56,3 @@ class TransposeDimensions(Transform):
if dims is None:
return inpt.as_subclass(torch.Tensor)
return inpt.transpose(*dims)
class SanitizeBoundingBoxes(Transform):
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
def __init__(
self,
min_size: float = 1.0,
labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
) -> None:
super().__init__()
if min_size < 1:
raise ValueError(f"min_size must be >= 1, got {min_size}.")
self.min_size = min_size
self.labels_getter = labels_getter
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
if labels_getter == "default":
self._labels_getter = self._find_labels_default_heuristic
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: inputs[labels_getter]
elif labels_getter is None:
self._labels_getter = None
else:
raise ValueError(
"labels_getter should either be a str, callable, or 'default'. "
f"Got {labels_getter} of type {type(labels_getter)}."
)
@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no samples and it is by design, pass labels_getter=None."
)
return inputs[candidate_key]
def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]
if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
f"then the input to forward() must be a dict. Got {type(inputs)} instead."
)
if self._labels_getter is None:
labels = None
else:
labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBox entry.
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
boxes = query_bounding_box(flat_inputs)
if boxes.ndim != 2:
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
if labels is not None and boxes.shape[0] != labels.shape[0]:
raise ValueError(
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
)
boxes = cast(
datapoints.BoundingBox,
F.convert_format_bounding_box(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
),
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.spatial_size
mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
params = dict(mask=mask, labels=labels)
flat_outputs = [
# Even-though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels
self._transform(inpt, params)
for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
inpt = inpt[params["mask"]]
return inpt
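# Usage sketch (illustration only; assumes torchvision 0.15.x, where the datapoint is still named
# BoundingBox with a `spatial_size` argument): with the default labels_getter the input must be a
# dict containing a "labels"-like key; degenerate boxes and their labels are dropped together.
import torch
from torchvision import datapoints
from torchvision.transforms import v2

boxes = datapoints.BoundingBox(
    [[10, 10, 20, 20], [5, 5, 5, 5]],  # second box has zero width/height
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(64, 64),
)
sample = {
    "image": datapoints.Image(torch.rand(3, 64, 64)),
    "boxes": boxes,
    "labels": torch.tensor([1, 2]),
}
out = v2.SanitizeBoundingBoxes(min_size=1.0)(sample)  # keeps only the first box and label 1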
......@@ -9,9 +9,9 @@ import PIL.Image
import torch
from torch import Tensor
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from torchvision.transforms.v2 import functional as F, InterpolationMode
from . import functional as F, InterpolationMode
from torchvision.transforms.v2.functional._geometry import _check_interpolation
__all__ = ["StereoMatching"]
......
from typing import Any, Dict, Optional, Union
from typing import Any, Dict
import numpy as np
import PIL.Image
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms.utils import is_simple_tensor
from torchvision.prototype import datapoints as proto_datapoints
from torchvision.transforms.v2 import Transform
class LabelToOneHot(Transform):
_transformed_types = (datapoints.Label,)
_transformed_types = (proto_datapoints.Label,)
def __init__(self, num_categories: int = -1):
super().__init__()
self.num_categories = num_categories
def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel:
def _transform(self, inpt: proto_datapoints.Label, params: Dict[str, Any]) -> proto_datapoints.OneHotLabel:
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return datapoints.OneHotLabel(output, categories=inpt.categories)
return proto_datapoints.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str:
if self.num_categories == -1:
return ""
return f"num_categories={self.num_categories}"
class PILToTensor(Transform):
_transformed_types = (PIL.Image.Image,)
def _transform(self, inpt: Union[PIL.Image.Image], params: Dict[str, Any]) -> torch.Tensor:
return F.pil_to_tensor(inpt)
class ToImageTensor(Transform):
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image:
return F.to_image_tensor(inpt)
class ToImagePIL(Transform):
_transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
self.mode = mode
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode)
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage = ToImagePIL