Unverified commit 3991ab99, authored by Philip Meier, committed by GitHub

Promote prototype transforms to beta status (#7261)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent d010e82f

torchvision/transforms/v2/__init__.py
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
from . import functional, utils # usort: skip
from ._transform import Transform # usort: skip
from ._augment import RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Grayscale,
RandomAdjustSharpness,
RandomAutocontrast,
RandomEqualize,
RandomGrayscale,
RandomInvert,
RandomPhotometricDistort,
RandomPosterize,
RandomSolarize,
)
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
CenterCrop,
ElasticTransform,
FiveCrop,
Pad,
RandomAffine,
RandomCrop,
RandomHorizontalFlip,
RandomIoUCrop,
RandomPerspective,
RandomResize,
RandomResizedCrop,
RandomRotation,
RandomShortestSize,
RandomVerticalFlip,
RandomZoomOut,
Resize,
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBoxes, ToDtype
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._deprecated import ToTensor # usort: skip

torchvision/transforms/v2/_augment.py
import math
import numbers
import warnings
from typing import Any, Dict, List, Tuple, Union
import PIL.Image
import torch
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F
from ._transform import _RandomApplyTransform
from .utils import is_simple_tensor, query_chw
class RandomErasing(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomErasing
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return dict(
super()._extract_params_for_v1_transform(),
value="random" if self.value is None else self.value,
)
_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
def __init__(
self,
p: float = 0.5,
scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0.0,
inplace: bool = False,
):
super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
self.scale = scale
self.ratio = ratio
if isinstance(value, (int, float)):
self.value = [float(value)]
elif isinstance(value, str):
self.value = None
elif isinstance(value, (list, tuple)):
self.value = [float(v) for v in value]
else:
self.value = value
self.inplace = inplace
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)
if self.value is not None and not (len(self.value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
)
area = img_h * img_w
log_ratio = self._log_ratio
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if self.value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(self.value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, None
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)
return inpt
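
# Usage sketch (illustrative, not part of the commit): RandomErasing blanks a
# random rectangle in an image tensor; plain tensors, datapoints.Image and
# PIL images are all accepted.
#   >>> import torch
#   >>> from torchvision.transforms.v2 import RandomErasing
#   >>> transform = RandomErasing(p=1.0, scale=(0.02, 0.33), value=0.0)
#   >>> out = transform(torch.rand(3, 224, 224))
#   >>> out.shape
#   torch.Size([3, 224, 224])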
torchvision/transforms/v2/_auto_augment.py
@@ -5,12 +5,11 @@ import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
-from torchvision import transforms as _transforms
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
-from torchvision.prototype.transforms.functional._geometry import _check_interpolation
-from torchvision.prototype.transforms.functional._meta import get_spatial_size
+from torchvision import datapoints, transforms as _transforms
+from torchvision.transforms import functional_tensor as _FT
+from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
+from torchvision.transforms.v2.functional._geometry import _check_interpolation
+from torchvision.transforms.v2.functional._meta import get_spatial_size
 from ._utils import _setup_fill_arg
 from .utils import check_type, is_simple_tensor
torchvision/transforms/v2/_color.py
@@ -3,9 +3,8 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import PIL.Image
 import torch
-from torchvision import transforms as _transforms
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F, Transform
+from torchvision import datapoints, transforms as _transforms
+from torchvision.transforms.v2 import functional as F, Transform
 from ._transform import _RandomApplyTransform
 from .utils import is_simple_tensor, query_chw
torchvision/transforms/v2/_container.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torchvision import transforms as _transforms
-from torchvision.prototype.transforms import Transform
+from torchvision.transforms.v2 import Transform

 class Compose(Transform):
torchvision/transforms/v2/_deprecated.py
@@ -4,10 +4,10 @@ from typing import Any, Dict, Union
 import numpy as np
 import PIL.Image
 import torch
-from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
+from torchvision.transforms.v2 import Transform

 class ToTensor(Transform):
     _transformed_types = (PIL.Image.Image, np.ndarray)

torchvision/transforms/v2/_geometry.py
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
from torchvision import datapoints, transforms as _transforms
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from ._transform import _RandomApplyTransform
from ._utils import (
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_setup_angle,
_setup_fill_arg,
_setup_float_or_seq,
_setup_size,
)
from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size
class RandomHorizontalFlip(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomHorizontalFlip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.horizontal_flip(inpt)
class RandomVerticalFlip(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomVerticalFlip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.vertical_flip(inpt)
class Resize(Transform):
_v1_transform_cls = _transforms.Resize
def __init__(
self,
size: Union[int, Sequence[int]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
if isinstance(size, int):
size = [size]
elif isinstance(size, (list, tuple)) and len(size) in {1, 2}:
size = list(size)
else:
raise ValueError(
f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead."
)
self.size = size
self.interpolation = _check_interpolation(interpolation)
self.max_size = max_size
self.antialias = antialias
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(
inpt,
self.size,
interpolation=self.interpolation,
max_size=self.max_size,
antialias=self.antialias,
)
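
# Usage sketch (illustrative): with a single-element size, Resize rescales the
# shorter edge and keeps the aspect ratio; passing `antialias=True` avoids the
# transitional "warn" default used above.
#   >>> import torch
#   >>> from torchvision.transforms.v2 import Resize
#   >>> Resize(size=[256], antialias=True)(torch.rand(3, 480, 640)).shape
#   torch.Size([3, 256, 341])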
class CenterCrop(Transform):
_v1_transform_cls = _transforms.CenterCrop
def __init__(self, size: Union[int, Sequence[int]]):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.center_crop(inpt, output_size=self.size)
class RandomResizedCrop(Transform):
_v1_transform_cls = _transforms.RandomResizedCrop
def __init__(
self,
size: Union[int, Sequence[int]],
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
scale = cast(Tuple[float, float], scale)
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
ratio = cast(Tuple[float, float], ratio)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
self.scale = scale
self.ratio = ratio
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
area = height * width
log_ratio = self._log_ratio
for _ in range(10):
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
break
else:
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.ratio):
w = width
h = int(round(w / min(self.ratio)))
elif in_ratio > max(self.ratio):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return dict(top=i, left=j, height=h, width=w)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resized_crop(
inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
)
ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
class FiveCrop(Transform):
"""
Example:
>>> class BatchMultiCrop(transforms.Transform):
... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], int]):
... images_or_videos, label = sample
... batch_size = len(images_or_videos)
... image_or_video = images_or_videos[0]
... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
... labels = torch.full((batch_size,), label, device=images_or_videos.device)
... return images_or_videos, labels
...
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> label = 3
>>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()])
>>> images, labels = transform(image, label)
>>> images.shape
torch.Size([5, 3, 224, 224])
>>> labels
tensor([3, 3, 3, 3, 3])
"""
_v1_transform_cls = _transforms.FiveCrop
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(
self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
return F.five_crop(inpt, self.size)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
class TenCrop(Transform):
"""
See :class:`~torchvision.transforms.v2.FiveCrop` for an example.
"""
_v1_transform_cls = _transforms.TenCrop
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Tuple[
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
class Pad(Transform):
_v1_transform_cls = _transforms.Pad
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
return params
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
self.padding = padding
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
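
# Usage sketch (illustrative): padding=2 adds two pixels on every side; `fill`
# may also be a dict mapping input types to fill values, as set up by
# _setup_fill_arg above.
#   >>> import torch
#   >>> from torchvision.transforms.v2 import Pad
#   >>> Pad(padding=2, fill=0)(torch.rand(3, 8, 8)).shape
#   torch.Size([3, 12, 12])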
class RandomZoomOut(_RandomApplyTransform):
def __init__(
self,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
super().__init__(p=p)
self.fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.pad(inpt, **params, fill=fill)
class RandomRotation(Transform):
_v1_transform_cls = _transforms.RandomRotation
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = _check_interpolation(interpolation)
self.expand = expand
self.fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.rotate(
inpt,
**params,
interpolation=self.interpolation,
expand=self.expand,
center=self.center,
fill=fill,
)
class RandomAffine(Transform):
_v1_transform_cls = _transforms.RandomAffine
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
translate: Optional[Sequence[float]] = None,
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2,))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
if self.translate is not None:
max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height)
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
translate = (tx, ty)
else:
translate = (0, 0)
if self.scale is not None:
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
else:
scale = 1.0
shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
if len(self.shear) == 4:
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
shear = (shear_x, shear_y)
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.affine(
inpt,
**params,
interpolation=self.interpolation,
fill=fill,
center=self.center,
)
class RandomCrop(Transform):
_v1_transform_cls = _transforms.RandomCrop
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
padding = self.padding
if padding is not None:
pad_left, pad_right, pad_top, pad_bottom = padding
padding = [pad_left, pad_top, pad_right, pad_bottom]
params["padding"] = padding
return params
def __init__(
self,
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if pad_if_needed or padding is not None:
if padding is not None:
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
self.pad_if_needed = pad_if_needed
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
padded_height, padded_width = query_spatial_size(flat_inputs)
if self.padding is not None:
pad_left, pad_right, pad_top, pad_bottom = self.padding
padded_height += pad_top + pad_bottom
padded_width += pad_left + pad_right
else:
pad_left = pad_right = pad_top = pad_bottom = 0
cropped_height, cropped_width = self.size
if self.pad_if_needed:
if padded_height < cropped_height:
diff = cropped_height - padded_height
pad_top += diff
pad_bottom += diff
padded_height += 2 * diff
if padded_width < cropped_width:
diff = cropped_width - padded_width
pad_left += diff
pad_right += diff
padded_width += 2 * diff
if padded_height < cropped_height or padded_width < cropped_width:
raise ValueError(
f"Required crop size {(cropped_height, cropped_width)} is larger than "
f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
)
# We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
padding = [pad_left, pad_top, pad_right, pad_bottom]
needs_pad = any(padding)
needs_vert_crop, top = (
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
if padded_height > cropped_height
else (False, 0)
)
needs_horz_crop, left = (
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
if padded_width > cropped_width
else (False, 0)
)
return dict(
needs_crop=needs_vert_crop or needs_horz_crop,
top=top,
left=left,
height=cropped_height,
width=cropped_width,
needs_pad=needs_pad,
padding=padding,
)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self.fill[type(inpt)]
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]:
inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
return inpt
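
# Usage sketch (illustrative): with pad_if_needed=True, inputs smaller than the
# crop size are padded symmetrically before the random crop is taken.
#   >>> import torch
#   >>> from torchvision.transforms.v2 import RandomCrop
#   >>> RandomCrop(size=32, pad_if_needed=True)(torch.rand(3, 24, 24)).shape
#   torch.Size([3, 32, 32])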
class RandomPerspective(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomPerspective
def __init__(
self,
distortion_scale: float = 0.5,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
super().__init__(p=p)
if not (0 <= distortion_scale <= 1):
raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
distortion_scale = self.distortion_scale
half_height = height // 2
half_width = width // 2
bound_height = int(distortion_scale * half_height) + 1
bound_width = int(distortion_scale * half_width) + 1
topleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
topright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
botright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
botleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
return dict(coefficients=perspective_coeffs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.perspective(
inpt,
None,
None,
fill=fill,
interpolation=self.interpolation,
**params,
)
class ElasticTransform(Transform):
_v1_transform_cls = _transforms.ElasticTransform
def __init__(
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = list(query_spatial_size(flat_inputs))
dx = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[0] > 0.0:
kx = int(8 * self.sigma[0] + 1)
# if kernel size is even we have to make it odd
if kx % 2 == 0:
kx += 1
dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / size[0]
dy = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[1] > 0.0:
ky = int(8 * self.sigma[1] + 1)
# if kernel size is even we have to make it odd
if ky % 2 == 0:
ky += 1
dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
dy = dy * self.alpha[1] / size[1]
displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.elastic(
inpt,
**params,
fill=fill,
interpolation=self.interpolation,
)
class RandomIoUCrop(Transform):
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
self.trials = trials
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_all(flat_inputs, datapoints.BoundingBox)
and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
"and bounding boxes. Sample can also contain masks."
)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
bboxes = query_bounding_box(flat_inputs)
while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value of 1 or more encodes the "leave as-is" option
return dict()
for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue
# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue
# FIXME: I think we can stop here?
# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_box(
bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue
# check at least 1 box with jaccard limitations
xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
ious = box_iou(
xyxy_bboxes,
torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
)
if ious.max() < min_jaccard_overlap:
continue
return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# FIXME: refactor this to not remove anything
if len(params) < 1:
return inpt
is_within_crop_area = params["is_within_crop_area"]
output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
if isinstance(output, datapoints.BoundingBox):
bboxes = output[is_within_crop_area]
bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size)
output = datapoints.BoundingBox.wrap_like(output, bboxes)
elif isinstance(output, datapoints.Mask):
# apply is_within_crop_area if mask is one-hot encoded
masks = output[is_within_crop_area]
output = datapoints.Mask.wrap_like(output, masks)
return output
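
# Usage sketch (illustrative): RandomIoUCrop requires images together with
# bounding boxes, and is typically followed by SanitizeBoundingBoxes.
#   >>> import torch
#   >>> from torchvision import datapoints
#   >>> from torchvision.transforms.v2 import RandomIoUCrop
#   >>> image = datapoints.Image(torch.rand(3, 100, 100))
#   >>> boxes = datapoints.BoundingBox(
#   ...     [[10, 10, 40, 40]], format="XYXY", spatial_size=(100, 100)
#   ... )
#   >>> out_image, out_boxes = RandomIoUCrop()(image, boxes)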
class ScaleJitter(Transform):
def __init__(
self,
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class RandomShortestSize(Transform):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: Optional[int] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min_size / min(orig_height, orig_width)
if self.max_size is not None:
r = min(r, self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class RandomResize(Transform):
def __init__(
self,
min_size: int,
max_size: int,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
self.min_size = min_size
self.max_size = max_size
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = int(torch.randint(self.min_size, self.max_size, ()))
return dict(size=[size])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
torchvision/transforms/v2/_meta.py
@@ -2,9 +2,8 @@ from typing import Any, Dict, Union
 import torch
-from torchvision import transforms as _transforms
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F, Transform
+from torchvision import datapoints, transforms as _transforms
+from torchvision.transforms.v2 import functional as F, Transform
 from .utils import is_simple_tensor

torchvision/transforms/v2/_misc.py
import collections
import warnings
from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
from .utils import has_any, is_simple_tensor, query_bounding_box
class Identity(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt
class Lambda(Transform):
def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__()
self.lambd = lambd
self.types = types or (object,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types):
return self.lambd(inpt)
else:
return inpt
def extra_repr(self) -> str:
extras = []
name = getattr(self.lambd, "__name__", None)
if name:
extras.append(name)
extras.append(f"types={[type.__name__ for type in self.types]}")
return ", ".join(extras)
class LinearTransformation(Transform):
_v1_transform_cls = _transforms.LinearTransformation
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError(
"transformation_matrix should be square. Got "
f"{tuple(transformation_matrix.size())} rectangular matrix."
)
if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError(
f"mean_vector should have the same length {mean_vector.size(0)}"
f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
)
if transformation_matrix.device != mean_vector.device:
raise ValueError(
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
)
if transformation_matrix.dtype != mean_vector.dtype:
raise ValueError(
f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
)
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
def _check_inputs(self, sample: Any) -> Any:
if has_any(sample, PIL.Image.Image):
raise TypeError("LinearTransformation does not work on PIL Images")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
shape = inpt.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
raise ValueError(
"Input tensor and transformation matrix have incompatible shape."
+ f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
+ f"{self.transformation_matrix.shape[0]}"
)
if inpt.device.type != self.mean_vector.device.type:
raise ValueError(
"Input tensor should be on the same device as transformation matrix and mean vector. "
f"Got {inpt.device} vs {self.mean_vector.device}"
)
flat_inpt = inpt.reshape(-1, n) - self.mean_vector
transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype)
output = torch.mm(flat_inpt, transformation_matrix)
output = output.reshape(shape)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
return output
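
# Usage sketch (illustrative): for inputs with C*H*W = n elements, the
# transformation matrix must be (n, n) and the mean vector (n,); an identity
# matrix and zero mean leave the input unchanged.
#   >>> import torch
#   >>> from torchvision.transforms.v2 import LinearTransformation
#   >>> n = 3 * 4 * 4
#   >>> lt = LinearTransformation(torch.eye(n), torch.zeros(n))
#   >>> out = lt(torch.rand(3, 4, 4))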
class Normalize(Transform):
_v1_transform_cls = _transforms.Normalize
_transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
self.mean = list(mean)
self.std = list(std)
self.inplace = inplace
def _check_inputs(self, sample: Any) -> Any:
if has_any(sample, PIL.Image.Image):
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
def _transform(
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
class GaussianBlur(Transform):
_v1_transform_cls = _transforms.GaussianBlur
def __init__(
self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
) -> None:
super().__init__()
self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
for ks in self.kernel_size:
if ks <= 0 or ks % 2 == 0:
raise ValueError("Kernel size value should be an odd and positive number.")
if isinstance(sigma, (int, float)):
if sigma <= 0:
raise ValueError("If sigma is a single number, it must be positive.")
sigma = float(sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
return dict(sigma=[sigma, sigma])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.gaussian_blur(inpt, self.kernel_size, **params)
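
# Usage sketch (illustrative): kernel_size must be odd; a single sigma is
# sampled per call from the (min, max) range and used for both axes.
#   >>> import torch
#   >>> from torchvision.transforms.v2 import GaussianBlur
#   >>> blurred = GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))(torch.rand(3, 32, 32))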
class ToDtype(Transform):
_transformed_types = (torch.Tensor,)
def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
dtype = _get_defaultdict(dtype)
if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
warnings.warn(
"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
)
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
dtype = self.dtype[type(inpt)]
if dtype is None:
return inpt
return inpt.to(dtype=dtype)
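
# Usage sketch (illustrative): `dtype` can be a mapping from input type to
# target dtype; entries mapped to None pass through unchanged. Note this only
# casts, it does not rescale values.
#   >>> import torch
#   >>> from torchvision import datapoints
#   >>> from torchvision.transforms.v2 import ToDtype
#   >>> t = ToDtype({datapoints.Image: torch.float32, datapoints.Video: None})
#   >>> image = datapoints.Image(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))
#   >>> t(image).dtype
#   torch.float32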
class SanitizeBoundingBoxes(Transform):
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
def __init__(
self,
min_size: float = 1.0,
labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
) -> None:
super().__init__()
if min_size < 1:
raise ValueError(f"min_size must be >= 1, got {min_size}.")
self.min_size = min_size
self.labels_getter = labels_getter
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
if labels_getter == "default":
self._labels_getter = self._find_labels_default_heuristic
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: inputs[labels_getter]
elif labels_getter is None:
self._labels_getter = None
else:
raise ValueError(
"labels_getter should either be a str, callable, or 'default'. "
f"Got {labels_getter} of type {type(labels_getter)}."
)
@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no samples and it is by design, pass labels_getter=None."
)
return inputs[candidate_key]
def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]
if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
f"then the input to forward() must be a dict. Got {type(inputs)} instead."
)
if self._labels_getter is None:
labels = None
else:
labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBox entry.
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
boxes = query_bounding_box(flat_inputs)
if boxes.ndim != 2:
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
if labels is not None and boxes.shape[0] != labels.shape[0]:
raise ValueError(
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
)
boxes = cast(
datapoints.BoundingBox,
F.convert_format_bounding_box(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
),
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.spatial_size
mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
params = dict(mask=mask, labels=labels)
flat_outputs = [
# Even though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels
self._transform(inpt, params)
for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
inpt = inpt[params["mask"]]
return inpt
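
# Usage sketch (illustrative): with the default labels_getter, the sample must
# be a dict containing a "labels"-like key; degenerate boxes and their labels
# are dropped together.
#   >>> import torch
#   >>> from torchvision import datapoints
#   >>> from torchvision.transforms.v2 import SanitizeBoundingBoxes
#   >>> boxes = datapoints.BoundingBox(
#   ...     [[0, 0, 10, 10], [5, 5, 5, 5]], format="XYXY", spatial_size=(20, 20)
#   ... )
#   >>> sample = {"boxes": boxes, "labels": torch.tensor([1, 2])}
#   >>> out = SanitizeBoundingBoxes()(sample)
#   >>> out["labels"]
#   tensor([1])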
torchvision/transforms/v2/_temporal.py
 from typing import Any, Dict

-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F, Transform
+from torchvision import datapoints
+from torchvision.transforms.v2 import functional as F, Transform

-from torchvision.prototype.transforms.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_simple_tensor

 class UniformTemporalSubsample(Transform):
     _transformed_types = (is_simple_tensor, datapoints.Video)

-    def __init__(self, num_samples: int, temporal_dim: int = -4):
+    def __init__(self, num_samples: int):
         super().__init__()
         self.num_samples = num_samples
-        self.temporal_dim = temporal_dim

     def _transform(self, inpt: datapoints.VideoType, params: Dict[str, Any]) -> datapoints.VideoType:
-        return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim)
+        return F.uniform_temporal_subsample(inpt, self.num_samples)
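
# Usage sketch (illustrative): after this change the transform always
# subsamples the temporal dimension of a (T, C, H, W) video.
#   >>> import torch
#   >>> from torchvision.transforms.v2 import UniformTemporalSubsample
#   >>> video = torch.rand(16, 3, 8, 8)  # T, C, H, W
#   >>> UniformTemporalSubsample(num_samples=4)(video).shape
#   torch.Size([4, 3, 8, 8])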
torchvision/transforms/v2/_transform.py
@@ -7,8 +7,8 @@ import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor
+from torchvision import datapoints
+from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once

torchvision/transforms/v2/_type_conversion.py
from typing import Any, Dict, Optional, Union
import numpy as np
import PIL.Image
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.utils import is_simple_tensor
class PILToTensor(Transform):
_transformed_types = (PIL.Image.Image,)
def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor:
return F.pil_to_tensor(inpt)
class ToImageTensor(Transform):
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image:
return F.to_image_tensor(inpt)
class ToImagePIL(Transform):
_transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
self.mode = mode
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode)
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage = ToImagePIL
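
# Usage sketch (illustrative): ToImageTensor converts PIL images, numpy arrays
# (HWC) and plain tensors into datapoints.Image (CHW).
#   >>> import numpy as np
#   >>> from torchvision.transforms.v2 import ToImageTensor
#   >>> img = ToImageTensor()(np.zeros((4, 4, 3), dtype=np.uint8))
#   >>> type(img).__name__, tuple(img.shape)
#   ('Image', (3, 4, 4))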
torchvision/transforms/v2/_utils.py
@@ -3,8 +3,8 @@ import numbers
 from collections import defaultdict
 from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union
-from torchvision.prototype import datapoints
-from torchvision.prototype.datapoints._datapoint import FillType, FillTypeJIT
+from torchvision import datapoints
+from torchvision.datapoints._datapoint import FillType, FillTypeJIT
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
torchvision/transforms/v2/functional/_augment.py
@@ -3,7 +3,7 @@ from typing import Union
 import PIL.Image
 import torch
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
torchvision/transforms/v2/functional/_color.py
@@ -3,7 +3,7 @@ from typing import Union
 import PIL.Image
 import torch
 from torch.nn.functional import conv2d
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional_tensor import _max_value
torchvision/transforms/v2/functional/_deprecated.py
@@ -4,7 +4,7 @@ from typing import Any, List, Union
 import PIL.Image
 import torch
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.transforms import functional as _F
torchvision/transforms/v2/functional/_geometry.py
@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional import (
     _check_antialias,
......
torchvision/transforms/v2/functional/_meta.py
@@ -2,8 +2,8 @@ from typing import List, Optional, Tuple, Union
 import PIL.Image
 import torch
-from torchvision.prototype import datapoints
-from torchvision.prototype.datapoints import BoundingBoxFormat
+from torchvision import datapoints
+from torchvision.datapoints import BoundingBoxFormat
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional_tensor import _max_value
torchvision/transforms/v2/functional/_misc.py
@@ -5,7 +5,7 @@ import PIL.Image
 import torch
 from torch.nn.functional import conv2d, pad as torch_pad
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once