"src/libtorchaudio/rir/rir.cpp" did not exist on "8c5c9a9bbf1dc3da60a6b89069d1775fc347fca1"
Unverified commit 3991ab99, authored by Philip Meier and committed by GitHub

Promote prototype transforms to beta status (#7261)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent d010e82f
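
For context, this promotion splits the old prototype namespace in two: datapoints and transforms that reached beta now live in torchvision.datapoints and torchvision.transforms.v2, while Label/OneHotLabel and a handful of transforms stay under torchvision.prototype. A rough before/after sketch of user imports (illustrative only, not part of the diff):

    # before: everything under the prototype namespace
    # from torchvision.prototype.datapoints import BoundingBox, Image, Label
    # from torchvision.prototype.transforms import RandomHorizontalFlip, Resize

    # after: promoted pieces move, Label stays prototype-only
    from torchvision.datapoints import BoundingBox, Image
    from torchvision.prototype.datapoints import Label
    from torchvision.transforms.v2 import RandomHorizontalFlip, Resize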
@@ -4,7 +4,8 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tupl
 import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.datapoints import BoundingBox
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
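
The same two-line split recurs in every dataset hunk below: promoted datapoints (BoundingBox, Image, Mask) move to torchvision.datapoints, while Label and OneHotLabel keep their prototype imports. As a reminder of what the promoted datapoints look like in use, a minimal sketch (values are made up; assumes the 0.15 beta API where spatial_size is (height, width)):

    import torch
    from torchvision.datapoints import BoundingBox, BoundingBoxFormat

    # Datapoints are tensor subclasses that carry metadata through transforms.
    box = BoundingBox(
        torch.tensor([[10.0, 20.0, 50.0, 60.0]]),
        format=BoundingBoxFormat.XYXY,
        spatial_size=(480, 640),  # (height, width) of the underlying image
    )
    print(box.format, box.spatial_size)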
@@ -6,7 +6,8 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, U
 import numpy as np
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Image, Label
+from torchvision.datapoints import Image
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
...
@@ -14,7 +14,8 @@ from torchdata.datapipes.iter import (
     Mapper,
     UnBatcher,
 )
-from torchvision.prototype.datapoints import BoundingBox, Label, Mask
+from torchvision.datapoints import BoundingBox, Mask
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
@@ -15,7 +15,8 @@ from torchdata.datapipes.iter import (
     Mapper,
 )
 from torchdata.datapipes.map import IterToMapConverter
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.datapoints import BoundingBox
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Union
 import torch
 from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Image, Label
+from torchvision.datapoints import Image
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
...
@@ -2,7 +2,8 @@ import pathlib
 from typing import Any, Dict, List, Optional, Tuple, Union
 from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.datapoints import BoundingBox
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
...
@@ -7,7 +7,8 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence
 import torch
 from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datapoints import Image, Label
+from torchvision.datapoints import Image
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE
 from torchvision.prototype.utils._internal import fromfile
...
@@ -4,7 +4,8 @@ from collections import namedtuple
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datapoints import Image, Label
+from torchvision.datapoints import Image
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
...
@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Tuple, Union
 import torch
 from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Image, OneHotLabel
+from torchvision.datapoints import Image
+from torchvision.prototype.datapoints import OneHotLabel
 from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
...
@@ -2,7 +2,8 @@ import pathlib
 from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.datapoints import BoundingBox
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
...
@@ -3,7 +3,8 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union
 import numpy as np
 from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher
-from torchvision.prototype.datapoints import Image, Label
+from torchvision.datapoints import Image
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat
...
@@ -3,7 +3,8 @@ from typing import Any, Dict, List, Union
 import torch
 from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper
-from torchvision.prototype.datapoints import Image, Label
+from torchvision.datapoints import Image
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
...
@@ -5,8 +5,9 @@ from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
 from xml.etree import ElementTree
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
+from torchvision.datapoints import BoundingBox
 from torchvision.datasets import VOCDetection
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
...
@@ -7,7 +7,7 @@ from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
 import PIL.Image
 import torch
-from torchvision.prototype.datapoints._datapoint import Datapoint
+from torchvision.datapoints._datapoint import Datapoint
 from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer

 D = TypeVar("D", bound="EncodedData")
...
-from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
-
-from . import functional, utils  # usort: skip
-
-from ._transform import Transform  # usort: skip
-
 from ._presets import StereoMatching  # usort: skip
-from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
-from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
-from ._color import (
-    ColorJitter,
-    Grayscale,
-    RandomAdjustSharpness,
-    RandomAutocontrast,
-    RandomEqualize,
-    RandomGrayscale,
-    RandomInvert,
-    RandomPhotometricDistort,
-    RandomPosterize,
-    RandomSolarize,
-)
-from ._container import Compose, RandomApply, RandomChoice, RandomOrder
-from ._geometry import (
-    CenterCrop,
-    ElasticTransform,
-    FiveCrop,
-    FixedSizeCrop,
-    Pad,
-    RandomAffine,
-    RandomCrop,
-    RandomHorizontalFlip,
-    RandomIoUCrop,
-    RandomPerspective,
-    RandomResize,
-    RandomResizedCrop,
-    RandomRotation,
-    RandomShortestSize,
-    RandomVerticalFlip,
-    RandomZoomOut,
-    Resize,
-    ScaleJitter,
-    TenCrop,
-)
-from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
-from ._misc import (
-    GaussianBlur,
-    Identity,
-    Lambda,
-    LinearTransformation,
-    Normalize,
-    PermuteDimensions,
-    SanitizeBoundingBoxes,
-    ToDtype,
-    TransposeDimensions,
-)
-from ._temporal import UniformTemporalSubsample
-from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
-
-from ._deprecated import ToTensor  # usort: skip
+from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste
+from ._geometry import FixedSizeCrop
+from ._misc import PermuteDimensions, TransposeDimensions
+from ._type_conversion import LabelToOneHot
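
After this rewrite, torchvision.prototype.transforms re-exports only the transforms that did not make the beta cut; everything deleted above is meant to be imported from the beta namespace instead. A sketch of the resulting split (assuming torchvision 0.15, where importing v2 required opting out of a beta warning first):

    import torchvision

    torchvision.disable_beta_transforms_warning()  # v2 warned at import time during beta

    # promoted to beta:
    from torchvision.transforms.v2 import AutoAugment, ColorJitter, Compose, RandomResizedCrop

    # still prototype-only after this commit:
    from torchvision.prototype.transforms import (
        FixedSizeCrop,
        LabelToOneHot,
        PermuteDimensions,
        RandomCutmix,
        RandomMixup,
        SimpleCopyPaste,
        StereoMatching,
        TransposeDimensions,
    )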
 import math
-import numbers
-import warnings
 from typing import Any, cast, Dict, List, Optional, Tuple, Union

 import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import transforms as _transforms
+from torchvision import datapoints
 from torchvision.ops import masks_to_boxes
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
-from torchvision.prototype.transforms.functional._geometry import _check_interpolation
-
-from ._transform import _RandomApplyTransform
-from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
+from torchvision.prototype import datapoints as proto_datapoints
+from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
+from torchvision.transforms.v2._transform import _RandomApplyTransform
+from torchvision.transforms.v2.functional._geometry import _check_interpolation
+from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_spatial_size
-class RandomErasing(_RandomApplyTransform):
-    _v1_transform_cls = _transforms.RandomErasing
-
-    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
-        return dict(
-            super()._extract_params_for_v1_transform(),
-            value="random" if self.value is None else self.value,
-        )
-
-    _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
-
-    def __init__(
-        self,
-        p: float = 0.5,
-        scale: Tuple[float, float] = (0.02, 0.33),
-        ratio: Tuple[float, float] = (0.3, 3.3),
-        value: float = 0.0,
-        inplace: bool = False,
-    ):
-        super().__init__(p=p)
-        if not isinstance(value, (numbers.Number, str, tuple, list)):
-            raise TypeError("Argument value should be either a number or str or a sequence")
-        if isinstance(value, str) and value != "random":
-            raise ValueError("If value is str, it should be 'random'")
-        if not isinstance(scale, (tuple, list)):
-            raise TypeError("Scale should be a sequence")
-        if not isinstance(ratio, (tuple, list)):
-            raise TypeError("Ratio should be a sequence")
-        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
-            warnings.warn("Scale and ratio should be of kind (min, max)")
-        if scale[0] < 0 or scale[1] > 1:
-            raise ValueError("Scale should be between 0 and 1")
-
-        self.scale = scale
-        self.ratio = ratio
-        if isinstance(value, (int, float)):
-            self.value = [float(value)]
-        elif isinstance(value, str):
-            self.value = None
-        elif isinstance(value, (list, tuple)):
-            self.value = [float(v) for v in value]
-        else:
-            self.value = value
-        self.inplace = inplace
-
-        self._log_ratio = torch.log(torch.tensor(self.ratio))
-
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        img_c, img_h, img_w = query_chw(flat_inputs)
-
-        if self.value is not None and not (len(self.value) in (1, img_c)):
-            raise ValueError(
-                f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
-            )
-
-        area = img_h * img_w
-
-        log_ratio = self._log_ratio
-        for _ in range(10):
-            erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
-            aspect_ratio = torch.exp(
-                torch.empty(1).uniform_(
-                    log_ratio[0],  # type: ignore[arg-type]
-                    log_ratio[1],  # type: ignore[arg-type]
-                )
-            ).item()
-
-            h = int(round(math.sqrt(erase_area * aspect_ratio)))
-            w = int(round(math.sqrt(erase_area / aspect_ratio)))
-            if not (h < img_h and w < img_w):
-                continue
-
-            if self.value is None:
-                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
-            else:
-                v = torch.tensor(self.value)[:, None, None]
-
-            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
-            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
-            break
-        else:
-            i, j, h, w, v = 0, 0, img_h, img_w, None
-
-        return dict(i=i, j=j, h=h, w=w, v=v)
-
-    def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        if params["v"] is not None:
-            inpt = F.erase(inpt, **params, inplace=self.inplace)
-
-        return inpt
-
 class _BaseMixupCutmix(_RandomApplyTransform):
@@ -118,19 +23,19 @@ class _BaseMixupCutmix(_RandomApplyTransform):
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
             has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
-            and has_any(flat_inputs, datapoints.OneHotLabel)
+            and has_any(flat_inputs, proto_datapoints.OneHotLabel)
         ):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Label):
+        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, proto_datapoints.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
            )

-    def _mixup_onehotlabel(self, inpt: datapoints.OneHotLabel, lam: float) -> datapoints.OneHotLabel:
+    def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel:
         if inpt.ndim < 2:
             raise ValueError("Need a batch of one hot labels")
         output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
-        return datapoints.OneHotLabel.wrap_like(inpt, output)
+        return proto_datapoints.OneHotLabel.wrap_like(inpt, output)

 class RandomMixup(_BaseMixupCutmix):
@@ -149,7 +54,7 @@ class RandomMixup(_BaseMixupCutmix):
             output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
             return output
-        elif isinstance(inpt, datapoints.OneHotLabel):
+        elif isinstance(inpt, proto_datapoints.OneHotLabel):
             return self._mixup_onehotlabel(inpt, lam)
         else:
             return inpt
@@ -193,7 +98,7 @@ class RandomCutmix(_BaseMixupCutmix):
             output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
             return output
-        elif isinstance(inpt, datapoints.OneHotLabel):
+        elif isinstance(inpt, proto_datapoints.OneHotLabel):
             lam_adjusted = params["lam_adjusted"]
             return self._mixup_onehotlabel(inpt, lam_adjusted)
         else:
@@ -307,7 +212,7 @@ class SimpleCopyPaste(Transform):
                 bboxes.append(obj)
             elif isinstance(obj, datapoints.Mask):
                 masks.append(obj)
-            elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
+            elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
                 labels.append(obj)

         if not (len(images) == len(bboxes) == len(masks) == len(labels)):
@@ -345,7 +250,7 @@ class SimpleCopyPaste(Transform):
             elif isinstance(obj, datapoints.Mask):
                 flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])
                 c2 += 1
-            elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)):
+            elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
                 flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
                 c3 += 1
...
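
The RandomErasing block above is deleted here because the transform was promoted, while RandomMixup, RandomCutmix, and SimpleCopyPaste stay in the prototype. A usage sketch of the promoted erasing transform against torchvision.transforms.v2 (parameter values are illustrative; the warning opt-out was required while v2 was in beta):

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()  # v2 emitted a beta warning at import time
    from torchvision.transforms import v2

    erase = v2.RandomErasing(p=1.0, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0.0)
    img = torch.rand(3, 224, 224)  # plain CHW tensor image
    out = erase(img)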
-import collections
 import warnings
-from contextlib import suppress
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

-import PIL.Image
 import torch
-from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import transforms as _transforms
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F, Transform
-
-from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
-from .utils import has_any, is_simple_tensor, query_bounding_box
+from torchvision import datapoints
+from torchvision.transforms.v2 import Transform
+from torchvision.transforms.v2._utils import _get_defaultdict
+from torchvision.transforms.v2.utils import is_simple_tensor
-class Identity(Transform):
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return inpt
-
-
-class Lambda(Transform):
-    def __init__(self, lambd: Callable[[Any], Any], *types: Type):
-        super().__init__()
-        self.lambd = lambd
-        self.types = types or (object,)
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, self.types):
-            return self.lambd(inpt)
-        else:
-            return inpt
-
-    def extra_repr(self) -> str:
-        extras = []
-        name = getattr(self.lambd, "__name__", None)
-        if name:
-            extras.append(name)
-        extras.append(f"types={[type.__name__ for type in self.types]}")
-        return ", ".join(extras)
-
-
-class LinearTransformation(Transform):
-    _v1_transform_cls = _transforms.LinearTransformation
-    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
-
-    def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
-        super().__init__()
-        if transformation_matrix.size(0) != transformation_matrix.size(1):
-            raise ValueError(
-                "transformation_matrix should be square. Got "
-                f"{tuple(transformation_matrix.size())} rectangular matrix."
-            )
-
-        if mean_vector.size(0) != transformation_matrix.size(0):
-            raise ValueError(
-                f"mean_vector should have the same length {mean_vector.size(0)}"
-                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
-            )
-
-        if transformation_matrix.device != mean_vector.device:
-            raise ValueError(
-                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
-            )
-
-        if transformation_matrix.dtype != mean_vector.dtype:
-            raise ValueError(
-                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
-            )
-
-        self.transformation_matrix = transformation_matrix
-        self.mean_vector = mean_vector
-
-    def _check_inputs(self, sample: Any) -> Any:
-        if has_any(sample, PIL.Image.Image):
-            raise TypeError("LinearTransformation does not work on PIL Images")
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        shape = inpt.shape
-        n = shape[-3] * shape[-2] * shape[-1]
-        if n != self.transformation_matrix.shape[0]:
-            raise ValueError(
-                "Input tensor and transformation matrix have incompatible shape."
-                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
-                + f"{self.transformation_matrix.shape[0]}"
-            )
-
-        if inpt.device.type != self.mean_vector.device.type:
-            raise ValueError(
-                "Input tensor should be on the same device as transformation matrix and mean vector. "
-                f"Got {inpt.device} vs {self.mean_vector.device}"
-            )
-
-        flat_inpt = inpt.reshape(-1, n) - self.mean_vector
-
-        transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype)
-        output = torch.mm(flat_inpt, transformation_matrix)
-        output = output.reshape(shape)
-
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
-        return output
-
-
-class Normalize(Transform):
-    _v1_transform_cls = _transforms.Normalize
-    _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)
-
-    def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
-        super().__init__()
-        self.mean = list(mean)
-        self.std = list(std)
-        self.inplace = inplace
-
-    def _check_inputs(self, sample: Any) -> Any:
-        if has_any(sample, PIL.Image.Image):
-            raise TypeError(f"{type(self).__name__}() does not support PIL images.")
-
-    def _transform(
-        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
-    ) -> Any:
-        return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
-
-
-class GaussianBlur(Transform):
-    _v1_transform_cls = _transforms.GaussianBlur
-
-    def __init__(
-        self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
-    ) -> None:
-        super().__init__()
-        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
-        for ks in self.kernel_size:
-            if ks <= 0 or ks % 2 == 0:
-                raise ValueError("Kernel size value should be an odd and positive number.")
-
-        if isinstance(sigma, (int, float)):
-            if sigma <= 0:
-                raise ValueError("If sigma is a single number, it must be positive.")
-            sigma = float(sigma)
-        elif isinstance(sigma, Sequence) and len(sigma) == 2:
-            if not 0.0 < sigma[0] <= sigma[1]:
-                raise ValueError("sigma values should be positive and of the form (min, max).")
-        else:
-            raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
-
-        self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
-
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
-        return dict(sigma=[sigma, sigma])
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.gaussian_blur(inpt, self.kernel_size, **params)
-
-
-class ToDtype(Transform):
-    _transformed_types = (torch.Tensor,)
-
-    def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
-        super().__init__()
-        if not isinstance(dtype, dict):
-            dtype = _get_defaultdict(dtype)
-        if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
-            warnings.warn(
-                "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
-                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
-                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
-            )
-        self.dtype = dtype
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        dtype = self.dtype[type(inpt)]
-        if dtype is None:
-            return inpt
-        return inpt.to(dtype=dtype)

 class PermuteDimensions(Transform):
@@ -225,115 +56,3 @@ class TransposeDimensions(Transform):
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
         return inpt.transpose(*dims)
-
-
-class SanitizeBoundingBoxes(Transform):
-    # This removes boxes and their corresponding labels:
-    # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
-    # - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
-
-    def __init__(
-        self,
-        min_size: float = 1.0,
-        labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
-    ) -> None:
-        super().__init__()
-
-        if min_size < 1:
-            raise ValueError(f"min_size must be >= 1, got {min_size}.")
-        self.min_size = min_size
-
-        self.labels_getter = labels_getter
-        self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
-        if labels_getter == "default":
-            self._labels_getter = self._find_labels_default_heuristic
-        elif callable(labels_getter):
-            self._labels_getter = labels_getter
-        elif isinstance(labels_getter, str):
-            self._labels_getter = lambda inputs: inputs[labels_getter]
-        elif labels_getter is None:
-            self._labels_getter = None
-        else:
-            raise ValueError(
-                "labels_getter should either be a str, callable, or 'default'. "
-                f"Got {labels_getter} of type {type(labels_getter)}."
-            )
-
-    @staticmethod
-    def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
-        # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive
-        # Returns None if nothing is found
-        candidate_key = None
-        with suppress(StopIteration):
-            candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
-        if candidate_key is None:
-            with suppress(StopIteration):
-                candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
-        if candidate_key is None:
-            raise ValueError(
-                "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
-                "If there are no samples and it is by design, pass labels_getter=None."
-            )
-        return inputs[candidate_key]
-
-    def forward(self, *inputs: Any) -> Any:
-        inputs = inputs if len(inputs) > 1 else inputs[0]
-
-        if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
-            raise ValueError(
-                f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
-                f"then the input to forward() must be a dict. Got {type(inputs)} instead."
-            )
-
-        if self._labels_getter is None:
-            labels = None
-        else:
-            labels = self._labels_getter(inputs)
-            if labels is not None and not isinstance(labels, torch.Tensor):
-                raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
-
-        flat_inputs, spec = tree_flatten(inputs)
-        # TODO: this enforces one single BoundingBox entry.
-        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
-        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
-        boxes = query_bounding_box(flat_inputs)
-
-        if boxes.ndim != 2:
-            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
-
-        if labels is not None and boxes.shape[0] != labels.shape[0]:
-            raise ValueError(
-                f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
-            )
-
-        boxes = cast(
-            datapoints.BoundingBox,
-            F.convert_format_bounding_box(
-                boxes,
-                new_format=datapoints.BoundingBoxFormat.XYXY,
-            ),
-        )
-        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
-        mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
-        # TODO: Do we really need to check for out of bounds here? All
-        # transforms should be clamping anyway, so this should never happen?
-        image_h, image_w = boxes.spatial_size
-        mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
-        mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
-
-        params = dict(mask=mask, labels=labels)
-        flat_outputs = [
-            # Even-though it may look like we're transforming all inputs, we don't:
-            # _transform() will only care about BoundingBoxes and the labels
-            self._transform(inpt, params)
-            for inpt in flat_inputs
-        ]
-
-        return tree_unflatten(flat_outputs, spec)
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
-            inpt = inpt[params["mask"]]
-
-        return inpt
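
SanitizeBoundingBoxes was promoted in the same way, so the prototype copy above is deleted rather than rewritten. A hedged sketch of calling the beta version, assuming it keeps the signature shown above (min_size, labels_getter) and the default heuristic that looks for a "labels" key in a sample dict:

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints
    from torchvision.transforms import v2

    boxes = datapoints.BoundingBox(
        torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 5.0, 5.0]]),  # second box is degenerate
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(100, 100),
    )
    sample = {"boxes": boxes, "labels": torch.tensor([1, 2])}
    out = v2.SanitizeBoundingBoxes(min_size=1.0)(sample)  # drops the degenerate box and its label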
@@ -9,9 +9,9 @@ import PIL.Image
 import torch
 from torch import Tensor
-from torchvision.prototype.transforms.functional._geometry import _check_interpolation
-from . import functional as F, InterpolationMode
+from torchvision.transforms.v2 import functional as F, InterpolationMode
+from torchvision.transforms.v2.functional._geometry import _check_interpolation

 __all__ = ["StereoMatching"]
...
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict

-import numpy as np
-import PIL.Image
 import torch
 from torch.nn.functional import one_hot
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.prototype.transforms.utils import is_simple_tensor
+from torchvision.prototype import datapoints as proto_datapoints
+from torchvision.transforms.v2 import Transform


 class LabelToOneHot(Transform):
-    _transformed_types = (datapoints.Label,)
+    _transformed_types = (proto_datapoints.Label,)

     def __init__(self, num_categories: int = -1):
         super().__init__()
         self.num_categories = num_categories

-    def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel:
+    def _transform(self, inpt: proto_datapoints.Label, params: Dict[str, Any]) -> proto_datapoints.OneHotLabel:
         num_categories = self.num_categories
         if num_categories == -1 and inpt.categories is not None:
             num_categories = len(inpt.categories)
         output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
-        return datapoints.OneHotLabel(output, categories=inpt.categories)
+        return proto_datapoints.OneHotLabel(output, categories=inpt.categories)

     def extra_repr(self) -> str:
         if self.num_categories == -1:
             return ""
         return f"num_categories={self.num_categories}"
-
-
-class PILToTensor(Transform):
-    _transformed_types = (PIL.Image.Image,)
-
-    def _transform(self, inpt: Union[PIL.Image.Image], params: Dict[str, Any]) -> torch.Tensor:
-        return F.pil_to_tensor(inpt)
-
-
-class ToImageTensor(Transform):
-    _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
-
-    def _transform(
-        self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
-    ) -> datapoints.Image:
-        return F.to_image_tensor(inpt)
-
-
-class ToImagePIL(Transform):
-    _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray)
-
-    def __init__(self, mode: Optional[str] = None) -> None:
-        super().__init__()
-        self.mode = mode
-
-    def _transform(
-        self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
-    ) -> PIL.Image.Image:
-        return F.to_image_pil(inpt, mode=self.mode)
-
-
-# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
-# prevalent and well understood. Thus, we just alias it without deprecating the old name.
-ToPILImage = ToImagePIL
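
LabelToOneHot stays behind because Label itself was not promoted. A small usage sketch (category names are made up):

    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms import LabelToOneHot

    label = datapoints.Label(torch.tensor([0, 2, 1]), categories=["cat", "dog", "bird"])
    onehot = LabelToOneHot()(label)  # num_categories is inferred from len(categories)
    print(onehot.shape)  # torch.Size([3, 3])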